/*

Copyright (C) 2003, 2004 Christian Kreibich <christian@whoop.org>.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies of the Software and its documentation and acknowledgment shall be
given in the documentation and software packages that this Software was
used.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

*/
#ifdef HAVE_CONFIG_H
#  include <config.h>
#endif

#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>

#include <protocols/ip/libnd_ip.h>
#include <protocols/tcp/libnd_tcp.h>
#include <protocols/udp/libnd_udp.h>
#include "libnd_conntrack.h"

/* Number of hash buckets in every connection table. 8009 is prime,
 * presumably chosen so the modulo in the slot computation spreads
 * hash values evenly.
 */
#define LND_CONN_TABLE_SLOTS 8009

typedef struct lnd_ct_item LND_CTItem;

/* Internal iteration callback: return TRUE to continue, FALSE to stop
 * the walk (see conn_table_foreach()).
 */
typedef gboolean (*LND_CTCallback)(LND_CTItem *item, void *user_data);

/* Wrapper around one tracked connection. Every item is linked into
 * exactly one hash-slot chain and into the table-wide age list.
 */
struct lnd_ct_item
{
  LND_ConnID  *conn;

  /* Double-linked list for hashtable overflow chain */
  LND_CTItem  *slot_prev;
  LND_CTItem  *slot_next;

  /* Double-linked list for ageing list: age_prev points towards
   * younger (more recently inserted) items, age_next towards older ones.
   */
  LND_CTItem  *age_prev;
  LND_CTItem  *age_next;
};

/* The connection table itself. */
struct lnd_conn_table
{
  LND_ConnTablePolicy policy;   /* whether dead connections are skipped */
  guint               size;     /* number of items currently stored */
  LND_CTItem        **slots;    /* LND_CONN_TABLE_SLOTS chain heads */
  LND_CTItem         *age;      /* youngest item (head of the age list) */
  LND_CTItem         *age_last; /* oldest item (tail of the age list) */
};


/* GHashFunc to hash connections: */
static guint
conn_hash(const LND_ConnID *conn)
{
  switch (conn->proto)
    {
    case IPPROTO_TCP:
    case IPPROTO_UDP:
      return
	((guint32) conn->proto) ^
	((guint32) conn->sport ^ (guint32) conn->ip_src.s_addr) ^
	((guint32) conn->dport ^ (guint32) conn->ip_dst.s_addr);
      
    default:
      return
	((guint32) conn->proto) ^ ((guint32) conn->ip_src.s_addr) ^ ((guint32) conn->ip_dst.s_addr);
    }

  /* not reached */
  return 0;
}

/* GCompareFunc to compare our connections: */
static gint 
conn_compare(const LND_ConnTable *ct, const LND_ConnID *c_old, const LND_ConnID *c_new)
{
  if (c_old->proto != c_new->proto)
    return FALSE;

  switch (c_old->proto)
    {
    case IPPROTO_TCP:
      /* Terminated connections stay in the hashtable but
       * are never matched successfully.
       */
      if (ct->policy == LND_CONN_TABLE_IGNORE_DEAD)
	{
	  if (libnd_tcpconn_is_dead((LND_TCPConn *) c_old, NULL) ||
	      libnd_tcpconn_is_dead((LND_TCPConn *) c_new, NULL))
	    {
	      return FALSE;
	    }
	}
      
      /* Now we fall through to the UDP case because it
       * does the same comparison operations.
       */
    case IPPROTO_UDP:

      /* FIXME some kind of timeout mechanism would be
       * nice here.
       */

      /*
      printf("Comparing %u %u %u %u to %u %u %u %u\n",
	     c_old->ip_src.s_addr, ntohs(c_old->sport),
	     c_old->ip_dst.s_addr, ntohs(c_old->dport),
	     c_new->ip_src.s_addr, ntohs(c_new->sport),
	     c_new->ip_dst.s_addr, ntohs(c_new->dport));
      */

      /* Direct match: */
      if (c_old->ip_src.s_addr == c_new->ip_src.s_addr &&
	  c_old->ip_dst.s_addr == c_new->ip_dst.s_addr &&
	  c_old->sport == c_new->sport                 &&
	  c_old->dport == c_new->dport)
	{
	  return TRUE;
	}
      
      /* Reverse direction match -- same flow! */
      if (c_old->ip_src.s_addr == c_new->ip_dst.s_addr &&
	  c_old->ip_dst.s_addr == c_new->ip_src.s_addr &&
	  c_old->sport == c_new->dport                 &&
	  c_old->dport == c_new->sport)
	{
	  return TRUE;
	}
      break;
      
    default: 
      /* Everything else: protocol, IP src, and IP dst must match */
      
      if (c_old->ip_src.s_addr == c_new->ip_src.s_addr &&
	  c_old->ip_dst.s_addr == c_new->ip_dst.s_addr)
	{
	  return TRUE;
	}
      
      if (c_old->ip_src.s_addr == c_new->ip_dst.s_addr &&
	  c_old->ip_dst.s_addr == c_new->ip_src.s_addr)
	{
	  return TRUE;
	}      
    }
  
  return FALSE;
}


/* Allocate a zeroed table item wrapping conn (ownership of conn is not
 * taken here). Returns NULL if the allocation reports failure.
 */
static LND_CTItem *
conn_table_item_new(LND_ConnID *conn)
{
  LND_CTItem *fresh = g_new0(LND_CTItem, 1);

  if (fresh == NULL)
    {
      D(("Out of memory.\n"));
      return NULL;
    }

  fresh->conn = conn;
  return fresh;
}

/* Release an item together with the connection it wraps.
 * A NULL item is ignored.
 */
static void
conn_table_item_free(LND_CTItem *item)
{
  if (item != NULL)
    {
      libnd_conn_free(item->conn);
      g_free(item);
    }
}


/* conn_table_foreach() callback that destroys every visited item.
 *
 * The signature matches LND_CTCallback exactly. The previous version
 * declared the second parameter as LND_ConnTable *, so the call through
 * the LND_CTCallback function pointer used an incompatible function
 * type, which is undefined behaviour in C.
 *
 * user_data: the LND_ConnTable being torn down; its size counter is
 * decremented for each freed item.
 * Returns TRUE so the iteration always continues.
 */
static gboolean
conn_table_item_free_cb(LND_CTItem *item, void *user_data)
{
  LND_ConnTable *ct = (LND_ConnTable *) user_data;

  conn_table_item_free(item);
  ct->size--;

  return TRUE;
}


/* Visit every item from youngest to oldest along the age list.
 *
 * The successor is captured before the callback runs, so the callback
 * may free or unlink the item it is given. Iteration stops as soon as
 * the callback returns FALSE. NULL ct or callback is a no-op.
 */
static void
conn_table_foreach(LND_ConnTable *ct, LND_CTCallback callback, void *user_data)
{
  LND_CTItem *cur, *following;

  if (ct == NULL || callback == NULL)
    return;

  for (cur = ct->age; cur != NULL; cur = following)
    {
      following = cur->age_next;

      if (!callback(cur, user_data))
        break;
    }
}


/* Locate the item matching conn (per conn_compare()) in its hash slot.
 *
 * On success the item is returned and, if slot_result is non-NULL, the
 * slot index is stored there. Returns NULL (leaving *slot_result
 * untouched) when nothing matches or ct/conn is NULL.
 */
static LND_CTItem *
conn_table_find_item(LND_ConnTable *ct, LND_ConnID *conn, guint *slot_result)
{
  LND_CTItem *cursor;
  guint bucket;

  if (ct == NULL || conn == NULL)
    return NULL;

  bucket = conn_hash(conn) % LND_CONN_TABLE_SLOTS;

  cursor = ct->slots[bucket];
  while (cursor != NULL && !conn_compare(ct, cursor->conn, conn))
    cursor = cursor->slot_next;

  if (cursor != NULL && slot_result != NULL)
    *slot_result = bucket;

  return cursor;
}


/* Unlink the item matching conn from both its hash chain and the age
 * list, and decrement the table size.
 *
 * The unlinked item is returned to the caller, who must free it; the
 * item's own link pointers are left as they were. Returns NULL if no
 * matching item exists.
 */
static LND_CTItem *
conn_table_remove(LND_ConnTable *ct, LND_ConnID *conn)
{
  LND_CTItem *victim, *before, *after;
  guint bucket;

  victim = conn_table_find_item(ct, conn, &bucket);
  if (victim == NULL)
    return NULL;

  /* Detach from the hash slot's overflow chain. */
  before = victim->slot_prev;
  after  = victim->slot_next;

  if (before != NULL)
    before->slot_next = after;
  if (after != NULL)
    after->slot_prev = before;
  if (ct->slots[bucket] == victim)
    ct->slots[bucket] = after;

  /* Detach from the age list, fixing up both ends if needed. */
  before = victim->age_prev;
  after  = victim->age_next;

  if (before != NULL)
    before->age_next = after;
  if (after != NULL)
    after->age_prev = before;
  if (ct->age == victim)
    ct->age = after;
  if (ct->age_last == victim)
    ct->age_last = before;

  ct->size--;

  return victim;
}


/* Create an empty connection table using the given policy.
 * Returns NULL if either allocation reports failure.
 */
LND_ConnTable *
libnd_conn_table_new(LND_ConnTablePolicy policy)
{
  LND_ConnTable *table = g_new0(LND_ConnTable, 1);

  if (table == NULL)
    {
      D(("Out of memory.\n"));
      return NULL;
    }

  table->slots = g_new0(LND_CTItem*, LND_CONN_TABLE_SLOTS);
  if (table->slots == NULL)
    {
      D(("Out of memory.\n"));
      g_free(table);
      return NULL;
    }

  table->policy = policy;
  return table;
}


/* Destroy a table together with every connection it still holds.
 *
 * Connections are released with libnd_conn_free(), so callers must not
 * keep pointers to connections that are still stored in the table.
 * NULL is ignored.
 */
void
libnd_conn_table_free(LND_ConnTable *ct)
{
  if (ct == NULL)
    return;

  ct->policy = LND_CONN_TABLE_INCLUDE_DEAD;

  /* Each item sits on the age list and on exactly one slot chain, so
   * walking the age list reaches (and frees) every item exactly once.
   */
  conn_table_foreach(ct, (LND_CTCallback) conn_table_item_free_cb, ct);

  g_free(ct->slots);
  g_free(ct);
}


LND_ConnTablePolicy
libnd_conn_table_get_policy(const LND_ConnTable *ct)
{
  if (!ct)
    return LND_CONN_TABLE_NA;

  return ct->policy;
}


/* Change the table's policy. A NULL table or LND_CONN_TABLE_NA as the
 * new policy leaves everything unchanged.
 */
void
libnd_conn_table_set_policy(LND_ConnTable *ct, LND_ConnTablePolicy policy)
{
  if (ct != NULL && policy != LND_CONN_TABLE_NA)
    ct->policy = policy;
}


/* Insert conn as the youngest entry of the table.
 *
 * The connection is pushed onto the front of its hash slot and onto the
 * head of the age list. No duplicate check is performed. NULL arguments
 * or a failed item allocation leave the table unchanged.
 */
void
libnd_conn_table_add(LND_ConnTable *ct, LND_ConnID *conn)
{
  LND_CTItem *fresh, **head;

  if (ct == NULL || conn == NULL)
    return;

  fresh = conn_table_item_new(conn);
  if (fresh == NULL)
    return;

  /* Push onto the hash chain. */
  head = &ct->slots[conn_hash(conn) % LND_CONN_TABLE_SLOTS];
  fresh->slot_next = *head;
  if (*head != NULL)
    (*head)->slot_prev = fresh;
  *head = fresh;

  /* First item ever stored is also the oldest. */
  if (ct->age_last == NULL)
    ct->age_last = fresh;

  /* Push onto the age list as the youngest item. */
  fresh->age_next = ct->age;
  if (ct->age != NULL)
    ct->age->age_prev = fresh;
  ct->age = fresh;

  ct->size++;
}


/* Find the tracked connection a packet belongs to.
 *
 * A lookup key is built from the packet's IP header (protocol, source and
 * destination address) plus, for TCP/UDP, the transport ports when that
 * header is available. Fragments without a transport header keep zero
 * ports. A match is moved to the front of the age list (move-to-front),
 * and then libnd_conn_is_dead() is consulted with the packet.
 *
 * Returns the connection, or NULL when the packet has no IP data, no
 * connection matches, or the match is dead and the policy is
 * LND_CONN_TABLE_IGNORE_DEAD. The move-to-front happens before that
 * dead check.
 */
LND_ConnID *
libnd_conn_table_lookup(LND_ConnTable *ct,
                        const LND_Packet *packet)
{
  LND_ProtoData *pd;
  struct ip *iph;
  LND_CTItem *found;
  LND_ConnID key;
  LND_ConnID *match;

  if (ct == NULL || packet == NULL)
    return NULL;

  memset(&key, 0, sizeof key);

  pd = libnd_packet_get_proto_data(packet, libnd_ip_get(), 0);
  if (pd == NULL)
    return NULL;

  iph = (struct ip *) pd->data;
  key.proto  = iph->ip_p;
  key.ip_src = iph->ip_src;
  key.ip_dst = iph->ip_dst;

  /* Add ports where a transport header is present; otherwise (e.g. a
   * fragment) they stay zero.
   */
  if (iph->ip_p == IPPROTO_TCP)
    {
      pd = libnd_packet_get_proto_data(packet, libnd_tcp_get(), 0);
      if (pd != NULL)
        {
          struct tcphdr *th = (struct tcphdr *) pd->data;

          key.sport = th->th_sport;
          key.dport = th->th_dport;
        }
    }
  else if (iph->ip_p == IPPROTO_UDP)
    {
      pd = libnd_packet_get_proto_data(packet, libnd_udp_get(), 0);
      if (pd != NULL)
        {
          struct udphdr *uh = (struct udphdr *) pd->data;

          key.sport = uh->uh_sport;
          key.dport = uh->uh_dport;
        }
    }

  /* Move-to-front: unlink the match, drop its wrapper, and re-add the
   * connection so it becomes the youngest entry.
   */
  found = conn_table_remove(ct, &key);
  if (found == NULL)
    return NULL;

  match = found->conn;
  g_free(found);
  libnd_conn_table_add(ct, match);

  if (match == NULL)
    return NULL;

  /* Always query deadness with the packet (which may update the
   * connection's state); only then apply the policy.
   */
  if (libnd_conn_is_dead(match, packet) &&
      ct->policy == LND_CONN_TABLE_IGNORE_DEAD)
    return NULL;

  return match;
}


/* Remove the connection matching conn from the table.
 *
 * Only the internal wrapper is freed; the connection itself is returned
 * and the caller becomes responsible for it. Returns NULL for invalid
 * input or when nothing matches.
 */
LND_ConnID *
libnd_conn_table_remove(LND_ConnTable *ct, LND_ConnID *conn)
{
  LND_CTItem *wrapper;
  LND_ConnID *detached;

  if (ct == NULL || conn == NULL)
    {
      D(("Invalid input.\n"));
      return NULL;
    }

  wrapper = conn_table_remove(ct, conn);
  if (wrapper == NULL)
    return NULL;

  detached = wrapper->conn;
  g_free(wrapper);

  return detached;
}


/* Return the least recently added/looked-up connection.
 *
 * The age list is walked from its tail (oldest) towards the head. Under
 * LND_CONN_TABLE_IGNORE_DEAD, connections reported dead are skipped.
 * Returns NULL for a NULL table or when no connection qualifies.
 */
LND_ConnID *
libnd_conn_table_get_oldest(LND_ConnTable *ct)
{
  LND_CTItem *cand;

  if (ct == NULL)
    return NULL;

  for (cand = ct->age_last; cand != NULL; cand = cand->age_prev)
    {
      if (!libnd_conn_is_dead(cand->conn, NULL) ||
          ct->policy != LND_CONN_TABLE_IGNORE_DEAD)
        return cand->conn;
    }

  return NULL;
}


/* Return the most recently added/looked-up connection.
 *
 * ct->age is the head of the age list, which libnd_conn_table_add()
 * (and thus the move-to-front in libnd_conn_table_lookup()) fills.
 * The walk therefore starts at the YOUNGEST entry and moves towards
 * older ones. The previous comment wrongly said "starting from oldest
 * entry".
 * Under LND_CONN_TABLE_IGNORE_DEAD, connections reported dead are
 * skipped. Returns NULL for a NULL table or when no connection
 * qualifies.
 */
LND_ConnID *
libnd_conn_table_get_youngest(LND_ConnTable *ct)
{
  LND_CTItem *item;

  if (!ct)
    return NULL;

  /* Go through age list starting from the youngest entry, and
   * pick first legitimate item.
   */
  for (item = ct->age; item; item = item->age_next)
    {
      if (libnd_conn_is_dead(item->conn, NULL) &&
          ct->policy == LND_CONN_TABLE_IGNORE_DEAD)
        continue;

      return item->conn;
    }

  return NULL;
}


/* Number of connections stored in the table (dead ones included);
 * 0 for a NULL table.
 */
guint
libnd_conn_table_size(LND_ConnTable *ct)
{
  return (ct != NULL) ? ct->size : 0;
}


/* Closure passed through conn_table_foreach() so that
 * libnd_conn_table_foreach() can apply the table's policy and forward
 * each connection to the user's callback.
 */
struct conn_table_cb
{
  LND_ConnTable   *ct;        /* table being iterated; source of the policy */
  LND_ConnFunc     func;      /* user callback; returning FALSE stops the walk */
  void            *user_data; /* passed through to func unchanged */
};

/* Adapter between conn_table_foreach() and a user LND_ConnFunc.
 *
 * Under LND_CONN_TABLE_IGNORE_DEAD, dead connections are skipped
 * (iteration continues). Otherwise the user's callback result decides
 * whether iteration continues.
 */
static gboolean
conn_table_foreach_cb(LND_CTItem *item, void *user_data)
{
  struct conn_table_cb *ctx = (struct conn_table_cb *) user_data;
  gboolean skip;

  skip = (ctx->ct->policy == LND_CONN_TABLE_IGNORE_DEAD &&
          libnd_conn_is_dead(item->conn, NULL));

  if (skip)
    return TRUE;

  return ctx->func(item->conn, ctx->user_data);
}

/* Call func on every connection, youngest first, honouring the table's
 * dead-connection policy. func returning FALSE ends the iteration.
 * NULL ct or func is a no-op.
 */
void
libnd_conn_table_foreach(LND_ConnTable *ct, LND_ConnFunc func, void *user_data)
{
  struct conn_table_cb ctx;

  if (ct == NULL || func == NULL)
    return;

  ctx.ct        = ct;
  ctx.func      = func;
  ctx.user_data = user_data;

  conn_table_foreach(ct, conn_table_foreach_cb, &ctx);
}


/* Call func on every connection, oldest first, honouring the table's
 * dead-connection policy. func returning FALSE ends the iteration.
 * NULL ct or func is a no-op.
 *
 * The predecessor is fetched before func runs, matching
 * conn_table_foreach(). func may therefore remove the connection it is
 * given (e.g. with libnd_conn_table_remove(), which frees the internal
 * wrapper) without the walk touching freed memory. Removing OTHER
 * connections from within func is still not supported.
 */
void
libnd_conn_table_foreach_oldest(LND_ConnTable *ct, LND_ConnFunc func, void *user_data)
{
  LND_CTItem *item, *item_prev;

  if (!ct || !func)
    return;

  for (item = ct->age_last; item; item = item_prev)
    {
      /* Grab the next (younger) item before func can free this one. */
      item_prev = item->age_prev;

      if (libnd_conn_is_dead(item->conn, NULL) &&
          ct->policy == LND_CONN_TABLE_IGNORE_DEAD)
        continue;

      if (! func(item->conn, user_data))
        return;
    }
}
