/*-GNU-GPL-BEGIN-*
nepim - network pipemeter
Copyright (C) 2005 Everton da Silva Marques

nepim is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.

nepim is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with nepim; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
Boston, MA 02111-1307, USA.
*-GNU-GPL-END-*/


/* $Id: sock.c,v 1.18 2005/08/05 11:12:13 evertonm Exp $ */


#include <assert.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>

#include "sock.h"

/*
   Fallback socket-option level numbers for systems (such as Solaris)
   whose headers do not define the Linux-style SOL_* names.  Each value
   equals the matching protocol number: IPPROTO_IP = 0,
   IPPROTO_IPV6 = 41, IPPROTO_TCP = 6.
*/
#ifndef SOL_IP
#define SOL_IP 0
#endif
#ifndef SOL_IPV6
#define SOL_IPV6 41
#endif
#ifndef SOL_TCP
#define SOL_TCP 6
#endif

/*
   Fallback for the IP_MTU getsockopt() option read by
   nepim_socket_pmtu_get_mtu(); 14 is the Linux value.
   NOTE(review): on non-Linux systems option number 14 may denote a
   different option -- confirm before trusting path-MTU readings there.
*/
#ifndef IP_MTU
#define IP_MTU 14
#endif

/*
   Status codes returned by the socket helpers in this file.
   NEPIM_SOCK_ERR_NONE (0) means success and every failure is negative,
   so functions that hand back a descriptor (e.g. nepim_create_socket())
   return either a valid fd (>= 0) or one of these codes (< 0).
   Note: some helpers pass through the raw setsockopt() result, whose
   failure value -1 coincides with NEPIM_SOCK_ERR_UNSPEC.
*/
#define NEPIM_SOCK_ERR_NONE           (0)
#define NEPIM_SOCK_ERR_UNSPEC         (-1)
#define NEPIM_SOCK_ERR_SOCKET         (-2)
#define NEPIM_SOCK_ERR_BIND           (-3)
#define NEPIM_SOCK_ERR_LISTEN         (-4)
#define NEPIM_SOCK_ERR_CONNECT        (-5)
#define NEPIM_SOCK_ERR_BLOCK          (-6)
#define NEPIM_SOCK_ERR_UNBLOCK        (-7)
#define NEPIM_SOCK_ERR_UNLINGER       (-8)
#define NEPIM_SOCK_ERR_REUSE          (-9)
#define NEPIM_SOCK_ERR_NODELAY        (-10)
#define NEPIM_SOCK_ERR_PMTU           (-11)
#define NEPIM_SOCK_ERR_TTL            (-12)
#define NEPIM_SOCK_ERR_MCAST_TTL      (-13)
#define NEPIM_SOCK_ERR_MCAST_JOIN     (-14)

/*
 * Return the port number, in host byte order, stored in an IPv4 or
 * IPv6 socket address.  sin_port and sin6_port sit at the same offset,
 * so the port is read without consulting the address family.
 */
int nepim_sock_get_port(const struct sockaddr *addr)
{
  const union {
    struct sockaddr_in v4;
    struct sockaddr_in6 v6;
  } *view = (const void *) addr;

  /* The shared-offset shortcut is only valid if both layouts agree. */
  assert(&view->v4.sin_port == &view->v6.sin6_port);
  assert(view->v4.sin_port == view->v6.sin6_port);

  return ntohs(view->v4.sin_port);
}

/*
 * Put socket sd into blocking mode by clearing O_NONBLOCK.
 *
 * Returns NEPIM_SOCK_ERR_NONE on success, or NEPIM_SOCK_ERR_BLOCK when
 * the file status flags cannot be read or written.
 */
int nepim_socket_block(int sd)
{
  /* fcntl(F_GETFL) returns an int, and F_SETFL takes its third argument
     as an int through varargs; a long here would pass the wrong type. */
  int flags;

  flags = fcntl(sd, F_GETFL, 0);
  if (flags == -1)
    return NEPIM_SOCK_ERR_BLOCK;
  assert(flags >= 0);

  /* POSIX only guarantees "a value other than -1" on F_SETFL success. */
  if (fcntl(sd, F_SETFL, flags & ~O_NONBLOCK) == -1)
    return NEPIM_SOCK_ERR_BLOCK;

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Put socket sd into non-blocking mode by setting O_NONBLOCK.
 *
 * Returns NEPIM_SOCK_ERR_NONE on success, or NEPIM_SOCK_ERR_UNBLOCK when
 * the file status flags cannot be read or written.
 */
int nepim_socket_nonblock(int sd)
{
  /* fcntl(F_GETFL) returns an int, and F_SETFL takes its third argument
     as an int through varargs; a long here would pass the wrong type. */
  int flags;

  flags = fcntl(sd, F_GETFL, 0);
  if (flags == -1)
    return NEPIM_SOCK_ERR_UNBLOCK;
  assert(flags >= 0);

  /* POSIX only guarantees "a value other than -1" on F_SETFL success. */
  if (fcntl(sd, F_SETFL, flags | O_NONBLOCK) == -1)
    return NEPIM_SOCK_ERR_UNBLOCK;

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Set the IPv4 path-MTU discovery mode (IP_MTU_DISCOVER) on sd.
 * A negative pmtu_mode leaves the current setting untouched.
 *
 * Returns 0 when nothing was done or when the platform has no
 * IP_MTU_DISCOVER; otherwise returns the setsockopt() result.
 */
int nepim_socket_pmtu(int sd, int pmtu_mode)
{
#ifdef IP_MTU_DISCOVER
  if (pmtu_mode >= 0)
    return setsockopt(sd, SOL_IP, IP_MTU_DISCOVER,
		      &pmtu_mode, sizeof(pmtu_mode));
#else
  (void) sd;
  (void) pmtu_mode;
#endif

  return NEPIM_SOCK_ERR_NONE;
}

/*
 * Set the unicast IPv4 TTL (IP_TTL) of sd.  A negative ttl leaves the
 * current value untouched.
 *
 * Returns 0 when skipped, otherwise the setsockopt() result.
 */
int nepim_socket_ttl(int sd, int ttl)
{
  int rc = NEPIM_SOCK_ERR_NONE;

  if (ttl >= 0)
    rc = setsockopt(sd, SOL_IP, IP_TTL, &ttl, sizeof ttl);

  return rc;
}

/*
 * Set the IPv4 multicast TTL (IP_MULTICAST_TTL) of sd.  A negative
 * mcast_ttl leaves the current value untouched.
 *
 * Returns 0 when skipped, otherwise the setsockopt() result (0 or -1).
 *
 * On NEPIM_SOLARIS builds the option is passed as an unsigned char.
 * socket_mcast_get_ttl() in this file uses the same width when reading
 * it back on that platform.  Out-of-range values are rejected there
 * instead of being silently truncated.
 */
int nepim_socket_mcast_ttl(int sd, int mcast_ttl)
{
  if (mcast_ttl < 0)
    return NEPIM_SOCK_ERR_NONE;

#ifdef NEPIM_SOLARIS
  {
    unsigned char opt;

    if (mcast_ttl > 255) {
      /* Same failure convention as setsockopt(): -1 with errno set. */
      errno = EINVAL;
      return -1;
    }
    opt = (unsigned char) mcast_ttl;

    return setsockopt(sd, SOL_IP, IP_MULTICAST_TTL, &opt, sizeof(opt));
  }
#else
  return setsockopt(sd, SOL_IP, IP_MULTICAST_TTL, &mcast_ttl, sizeof(mcast_ttl));
#endif
}

static int create_socket(int domain, int type, int protocol, int pmtu_mode, int ttl)
{
  int sd;
  int result;

  sd = socket(domain, type, protocol);
  if (sd < 0)
    return NEPIM_SOCK_ERR_SOCKET;

  if (type == SOCK_STREAM) {
    result = nepim_socket_tcp_opt(sd);
    if (result) {
      close(sd);
      return result;
    }
  }

  result = nepim_socket_opt(sd, pmtu_mode, ttl);
  if (result) {
    close(sd);
    return result;
  }

  return sd;
}

/*
 * Join the multicast group whose address is in addr on socket sd,
 * letting the kernel choose the interface.
 *
 * family selects how addr is interpreted (PF_INET or PF_INET6);
 * addr_len is currently not consulted.
 *
 * Returns the setsockopt() result (0 on success) for a supported
 * family.  Returns NEPIM_SOCK_ERR_MCAST_JOIN for an unknown family
 * (after an assertion in debug builds), and always on NEPIM_SOLARIS
 * builds, where joining is not implemented.
 */
static int socket_mcast_join(int sd, int family, struct sockaddr *addr, int addr_len)
{
#ifndef NEPIM_SOLARIS
  union {
    struct sockaddr_in inet;
    struct sockaddr_in6 inet6;
  } *sa = (void *) addr;

  (void) addr_len;

  switch (family) {
  case PF_INET:
    {
      struct ip_mreqn opt;

      memset(&opt, 0, sizeof(opt));
      opt.imr_multiaddr = sa->inet.sin_addr;
      /* s_addr is a 32-bit field: convert with htonl(), not htons(). */
      opt.imr_address.s_addr = htonl(INADDR_ANY);
      opt.imr_ifindex = 0; /* 0 = any interface */

      return setsockopt(sd, SOL_IP, IP_ADD_MEMBERSHIP, &opt, sizeof(opt));
    }
  case PF_INET6:
    {
      struct ipv6_mreq opt;

      memset(&opt, 0, sizeof(opt));
      opt.ipv6mr_multiaddr = sa->inet6.sin6_addr;
      opt.ipv6mr_interface = 0; /* 0 = default multicast interface */

      return setsockopt(sd, SOL_IPV6, IPV6_ADD_MEMBERSHIP, &opt, sizeof(opt));
    }
  default:
    assert(0);
    break;
  }
#else
  (void) sd;
  (void) family;
  (void) addr;
  (void) addr_len;
#endif /* NEPIM_SOLARIS */

  return NEPIM_SOCK_ERR_MCAST_JOIN;
}

/*
 * Create a socket with nepim's standard options, optionally join the
 * multicast group given by addr, switch it to non-blocking mode, and
 * bind it to addr.
 *
 * mcast_join is only valid for UDP datagram sockets (asserted).
 * Returns the bound descriptor, or a negative NEPIM_SOCK_ERR_* code.
 * Once the socket exists, every failure path closes it.
 */
int nepim_create_socket(struct sockaddr *addr,
			int addr_len,
			int family,
			int type,
			int protocol,
			int pmtu_mode,
			int ttl,
			int mcast_join)
{
  int fd;
  int err;

  fd = create_socket(family, type, protocol, pmtu_mode, ttl);
  if (fd < 0)
    return fd;

  if (mcast_join) {
    assert(type == SOCK_DGRAM);
    assert(protocol == IPPROTO_UDP);

    if (socket_mcast_join(fd, family, addr, addr_len)) {
      err = NEPIM_SOCK_ERR_MCAST_JOIN;
      goto fail;
    }
  }

  err = nepim_socket_nonblock(fd);
  if (err)
    goto fail;

  if (bind(fd, addr, addr_len)) {
    err = NEPIM_SOCK_ERR_BIND;
    goto fail;
  }

  return fd;

 fail:
  close(fd);
  return err;
}

/*
 * Create and bind a socket via nepim_create_socket() (no multicast
 * join), then put it into the listening state with the given backlog.
 *
 * Returns the listening descriptor or a negative NEPIM_SOCK_ERR_* code;
 * if listen() fails the descriptor is closed.
 */
int nepim_create_listener_socket(struct sockaddr *addr,
				 int addr_len,
				 int family,
				 int type,
				 int protocol,
				 int backlog,
				 int pmtu_mode,
				 int ttl)
{
  const int no_mcast_join = 0;
  int fd = nepim_create_socket(addr, addr_len, family, type, protocol,
			       pmtu_mode, ttl, no_mcast_join);

  if (fd >= 0 && listen(fd, backlog)) {
    close(fd);
    fd = NEPIM_SOCK_ERR_LISTEN;
  }

  return fd;
}

/*
 * Explicitly switch SO_LINGER off on sd (l_onoff = 0, l_linger = 0).
 * Returns the setsockopt() result: 0 on success, -1 on error.
 */
static int unlinger(int sd)
{
  struct linger lg = { 0 }; /* every field zero: lingering disabled */

  return setsockopt(sd, SOL_SOCKET, SO_LINGER, &lg, sizeof lg);
}

/*
 * Enable SO_REUSEADDR on sd.
 * Returns the setsockopt() result: 0 on success, -1 on error.
 */
static int reuse(int sd)
{
  const int enable = 1;

  return setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof enable);
}

/*
 * Disable Nagle's algorithm (TCP_NODELAY) on TCP socket sd.
 * Returns the setsockopt() result: 0 on success, -1 on error.
 */
static int tcp_nodelay(int sd)
{
  int opt = 1;

  /* IPPROTO_TCP is the portable (POSIX) option level for TCP_NODELAY;
     SOL_TCP is a Linux-only alias with the same value. */
  return setsockopt(sd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt));
}

/*
 * Apply the options nepim sets on every socket, in order:
 * linger off, SO_REUSEADDR on, PMTU discovery mode, then TTL
 * (a negative pmtu_mode or ttl leaves that setting untouched).
 *
 * Stops at the first failure and returns its NEPIM_SOCK_ERR_* code;
 * returns NEPIM_SOCK_ERR_NONE when every step succeeded.
 */
int nepim_socket_opt(int sd, int pmtu_mode, int ttl)
{
  int err = NEPIM_SOCK_ERR_NONE;

  if (unlinger(sd))
    err = NEPIM_SOCK_ERR_UNLINGER;
  else if (reuse(sd))
    err = NEPIM_SOCK_ERR_REUSE;
  else if (nepim_socket_pmtu(sd, pmtu_mode))
    err = NEPIM_SOCK_ERR_PMTU;
  else if (nepim_socket_ttl(sd, ttl))
    err = NEPIM_SOCK_ERR_TTL;

  return err;
}

/*
 * Apply the TCP-only socket options (currently just TCP_NODELAY).
 * Returns NEPIM_SOCK_ERR_NONE on success, NEPIM_SOCK_ERR_NODELAY otherwise.
 */
int nepim_socket_tcp_opt(int sd)
{
  return tcp_nodelay(sd) ? NEPIM_SOCK_ERR_NODELAY : NEPIM_SOCK_ERR_NONE;
}

/*
 * Create a socket with nepim's standard options and connect it to addr.
 *
 * connect() is performed in blocking mode, so this call can stall until
 * the connection completes or fails (hence the FIXME debug line).  On
 * success the socket is switched to non-blocking mode before it is
 * returned.
 *
 * Returns the connected descriptor or a negative NEPIM_SOCK_ERR_* code.
 * Once the socket exists, every failure path closes it.
 */
int nepim_connect_client_socket(struct sockaddr *addr,
				int addr_len,
				int family,
				int type,
				int protocol,
				int pmtu_mode,
				int ttl)
{
  int sd;
  int result;

  sd = create_socket(family, type, protocol, pmtu_mode, ttl);
  if (sd < 0)
    return sd;

#ifdef SO_BSDCOMPAT
  /*
   * We don't want Linux ECONNREFUSED on UDP sockets
   */
  if (protocol == IPPROTO_UDP) {
    int one = 1;
    if (setsockopt(sd, SOL_SOCKET, SO_BSDCOMPAT, &one, sizeof(one))) {
      /* Must go through the cleanup path: returning directly would leak sd. */
      result = NEPIM_SOCK_ERR_UNSPEC;
      goto fail;
    }
  }
#endif /* Linux SO_BSDCOMPAT */

  result = nepim_socket_block(sd);
  if (result)
    goto fail;

  fprintf(stderr, 
	  "DEBUG FIXME %s %s slow synchronous connect(port=%d)\n",
	  __FILE__, __PRETTY_FUNCTION__, nepim_sock_get_port(addr));

  if (connect(sd, addr, addr_len)) {
    result = NEPIM_SOCK_ERR_CONNECT;
    goto fail;
  }

  result = nepim_socket_nonblock(sd);
  if (result)
    goto fail;

  return sd;

 fail:
  close(sd);
  return result;
}

/*
 * Read back the IPv4 path-MTU discovery mode (IP_MTU_DISCOVER) of sd.
 * Returns the mode, or NEPIM_SOCK_ERR_PMTU if it cannot be read or the
 * platform does not define IP_MTU_DISCOVER.
 */
int nepim_socket_pmtu_get_mode(int sd)
{
#ifndef IP_MTU_DISCOVER
  (void) sd;
  return NEPIM_SOCK_ERR_PMTU;
#else
  int val;
  socklen_t len = sizeof(val);

  if (getsockopt(sd, SOL_IP, IP_MTU_DISCOVER, &val, &len) != 0)
    return NEPIM_SOCK_ERR_PMTU;

  assert(len == sizeof(val));
  return val;
#endif
}

/*
 * Read the path MTU currently known for sd (IP_MTU).
 * Returns the MTU, or NEPIM_SOCK_ERR_PMTU if getsockopt() fails.
 * NOTE(review): on Linux IP_MTU is documented to work only on connected
 * sockets -- confirm callers only use it after connect().
 */
int nepim_socket_pmtu_get_mtu(int sd)
{
  int path_mtu = 0;
  socklen_t len = sizeof path_mtu;

  if (getsockopt(sd, SOL_IP, IP_MTU, &path_mtu, &len) == 0) {
    assert(len == sizeof path_mtu);
    return path_mtu;
  }

  return NEPIM_SOCK_ERR_PMTU;
}

/*
 * Read the unicast IPv4 TTL (IP_TTL) of sd.
 * Returns the TTL, or NEPIM_SOCK_ERR_TTL if getsockopt() fails.
 */
int nepim_socket_get_ttl(int sd)
{
  int hops = 0;
  socklen_t len = sizeof hops;

  if (getsockopt(sd, SOL_IP, IP_TTL, &hops, &len) == 0) {
    assert(len == sizeof hops);
    return hops;
  }

  return NEPIM_SOCK_ERR_TTL;
}

/*
 * Read the IPv4 multicast TTL (IP_MULTICAST_TTL) of sd.
 * The option is read into an unsigned char on NEPIM_SOLARIS builds and
 * into an int elsewhere.
 * Returns the TTL, or NEPIM_SOCK_ERR_MCAST_TTL if getsockopt() fails.
 */
static int socket_mcast_get_ttl(int sd)
{
#ifdef NEPIM_SOLARIS
  typedef unsigned char mcast_ttl_t;
#else
  typedef int mcast_ttl_t;
#endif
  mcast_ttl_t value;
  socklen_t len = sizeof(value);

  if (getsockopt(sd, SOL_IP, IP_MULTICAST_TTL, &value, &len))
    return NEPIM_SOCK_ERR_MCAST_TTL;

  assert(len == sizeof(value));

  return value;
}

/*
 * Print one line to out describing sd's IP-level options: PMTU discovery
 * mode, path MTU, TTL and multicast TTL.  Any value that could not be
 * read is printed as the corresponding negative NEPIM_SOCK_ERR_* code.
 */
void nepim_sock_show_opt(FILE *out, int sd)
{
  const int pmtud_mode = nepim_socket_pmtu_get_mode(sd);
  const int mtu = nepim_socket_pmtu_get_mtu(sd);
  const int ttl = nepim_socket_get_ttl(sd);
  const int mcast_ttl = socket_mcast_get_ttl(sd);

  fprintf(out, "%d: pmtud_mode=%d path_mtu=%d ttl=%d mcast_ttl=%d\n",
	  sd, pmtud_mode, mtu, ttl, mcast_ttl);
}
