/* Copyright (C) 1999, 2000, 2001 Simon Patarin, INRIA

This file is part of Pandora, the Flexible Monitoring Platform.

Pandora is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2, or (at your option)
any later version.

Pandora is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Pandora; see the file COPYING.  If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
Boston, MA 02111-1307, USA.  */

#include <libpandora/global.h>

#include <iostream>
#include <iomanip>
#include <pandora_components/dnspacket.h>
#include <pandora_components/udppacket.h>
#include <pandora_components/ippacket.h>
#include <libpandora/pandorakey.h>
#include <libpandora/algo_funcs.h>
#include <libpandora/timeval.h>
#include <libpandora/serialize.h>
#include <libpandora/error.h>

// libpandora registration macro. Presumably exports DNSPacket as a packet
// type decoded from (layered on) a UDPPacket, matching the constructor
// below that takes a UDPPacket* -- TODO confirm against packet_export.
packet_export(DNSPacket, UDPPacket);

/* Decode a DNS message carried in the payload of a UDP packet.
 *
 * The fixed DNS header is copied into the header fields, the four section
 * arrays are sized from the header counts, and the question, answer,
 * authority and additional sections are then parsed in that order.
 * Outcomes:
 *  - no IP layer found: returns immediately, nothing else is done;
 *  - payload shorter than the DNS header: warns and calls cleanPacket(udpp);
 *  - a record fails to parse: decoding stops early, `type` stays
 *    `undefined` and `timeStamp` is not copied, but packetSetUp(udpp) is
 *    still called;
 *  - everything parsed: `type` is set from the QR bit, `timeStamp` is
 *    copied from udpp, and packetSetUp(udpp) is called.
 */
DNSPacket::DNSPacket(UDPPacket* udpp) 
  : type(undefined), length(0)
{
  locatePacket(IPPacket, ipp, udpp);
  if (ipp == NULL) return;

  // NOTE(review): the size check uses sizeof(HEADER) while the skip below
  // uses HFIXEDSZ; both are expected to be the 12-byte fixed header -- confirm.
  u_int hlen = sizeof(HEADER);

  // Not even room for the fixed DNS header: give up on this packet.
  if (ipp->dlength() < hlen) {
    pandora_warning("truncated DNS packet");
    cleanPacket(udpp);
    return;
  }
  
  // NOTE(review): if udpp->length is unsigned and smaller than hlen this
  // wraps around -- TODO confirm what udpp->length measures.
  length = udpp->length - hlen;

  HEADER *dns = (HEADER *)(ipp->data());

  // Bounds of the whole message; expandName() passes them to dn_expand()
  // so compression pointers can be resolved and range-checked.
  bom = (u_char *) ipp->data();
  eom = bom + ipp->dlength();

  // NOTE(review): id is stored as-is (no ntohs()), unlike the section
  // counts below -- confirm that wire byte order is intended here.
  id = dns->id;
  opcode = dns->opcode;
  aa = dns->aa;
  tc = dns->tc;
  rd = dns->rd;
  ra = dns->ra;
  rcode = dns->rcode;
  
  // Section counts are converted to host order and size the record arrays.
  qd.init(ntohs(dns->qdcount));
  an.init(ntohs(dns->ancount));
  ns.init(ntohs(dns->nscount));
  ar.init(ntohs(dns->arcount));

  // Skip the fixed header so ipp->data() refers to the first question.
  // NOTE(review): `dns` is still dereferenced further down (dns->qr), which
  // assumes move() only advances a view and leaves the bytes in place.
  (ipp->_data).move(HFIXEDSZ);
  
  // Each parser consumes one entry from ipp's data; a negative result
  // means truncated/malformed input and aborts decoding.
  for (int i = 0; i < qd.size(); ++i) {
    dns_query_t *q = &qd[i];
    int n = parseQuery(q, udpp);
    if (n < 0) goto finished;
  }
  
  for (int i = 0; i < an.size(); ++i) {
    rrecord_t *r = &an[i];
    int n = parseRRecord(r, udpp);
    if (n < 0) goto finished;
  }

  for (int i = 0; i < ns.size(); ++i) {
    rrecord_t *r = &ns[i];
    int n = parseRRecord(r, udpp);
    if (n < 0) goto finished;
  }

  for (int i = 0; i < ar.size(); ++i) {
    rrecord_t *r = &ar[i];
    int n = parseRRecord(r, udpp);
    if (n < 0) goto finished;
  }

  // Only a fully parsed message gets a request/response type and timestamp.
  type = (dns->qr == 0) ? request : response;
  timeStamp = udpp->timeStamp; 

  // Reached both after a complete decode and after a parse failure.
 finished:
  packetSetUp(udpp);
  return;
}

/* Copy constructor: duplicates the base Packet, every decoded header field
 * and the four record sections. bom/eom are left NULL on purpose: in the
 * source object they point into the IP payload it was decoded from, and
 * that buffer is not part of the copy.
 */
DNSPacket::DNSPacket(const DNSPacket& x)
  : Packet(x),
    type(x.type),
    id(x.id),
    opcode(x.opcode),
    aa(x.aa),
    tc(x.tc),
    rd(x.rd),
    ra(x.ra),
    rcode(x.rcode),
    length(x.length),
    qd(x.qd),
    an(x.an),
    ns(x.ns),
    ar(x.ar),
    bom(NULL),
    eom(NULL)
{
  // Every member is set in the initializer list above.
}

/* Assignment: copies the base Packet, every decoded header field and the
 * four record sections from x.
 * Like the copy constructor, bom/eom are reset to NULL. After assignment
 * the decoded fields describe x's message, so bounds left over from this
 * object's own decoding would no longer match them, and x's bounds point
 * into a buffer this object does not hold.
 */
DNSPacket& DNSPacket::operator= (const DNSPacket& x) 
{
  Packet::operator=(x);
  type = x.type;
  id = x.id; opcode = x.opcode; aa = x.aa; 
  tc = x.tc; rd = x.rd; ra = x.ra;
  rcode = x.rcode; length = x.length; 
  qd = x.qd; an = x.an; ns = x.ns; ar = x.ar;

  // Keep the same invariant as DNSPacket(const DNSPacket&).
  bom = NULL;
  eom = NULL;

  return *this;
}

/* Print a one-line summary of this DNS message to *f. The line contains:
 *  - the timestamp and "[dns] ";
 *  - src:sport and dst:dport;
 *  - the message type and id;
 *  - the first question (type and name, or "?" when there is none);
 *  - the answer, authority and additional sections, each record printed
 *    by printRR() or "-" for an empty section.
 * The line is terminated by endl.
 * Nothing is printed if the UDP or IP layer cannot be located.
 */
void DNSPacket::print(ostream *f)
{
  locatePacket(UDPPacket, udpp, this);
  // Same NULL guards as the other accessors in this file: the header line
  // below dereferences both lower layers.
  if (udpp == NULL) return;
  locatePacket(IPPacket, ipp, udpp);
  if (ipp == NULL) return;

  *f << timeStamp << '\t'
     << "[dns] "
     << ipp->src << ':' << ntohs(udpp->sport) << ' ' 
     << ipp->dst << ':' << ntohs(udpp->dport) << ' ';

  *f << (int) type << ' ' << id << ' ';

  if (qd.size() > 0) *f << (int)qd[0].type << " " << qd[0].name << ' ';
  else *f << "? ";

  if (an.size() > 0) {
    for (int i = 0; i < an.size(); ++i) {
      printRR(f, an[i]);
      *f << ' ';
    }
  } else *f << "- ";

  if (ns.size() > 0) {
    for (int i = 0; i < ns.size(); ++i) {
      printRR(f, ns[i]);
      *f << ' ';
    }
  } else *f << "- ";

  if (ar.size() > 0) {
    for (int i = 0; i < ar.size(); ++i) {
      printRR(f, ar[i]);
      *f << ' ';
    }
  } else *f << "- ";

  *f << endl;
}

/* Expand the (possibly compressed) domain name that starts at msg into the
 * buffer name, resolving compression pointers against the message bounds
 * [bom, eom) recorded by the constructor. name is treated as a buffer of
 * MAXDNAME bytes. Returns dn_expand()'s result: the size of the encoded
 * name at msg, or a negative value on error.
 */
int DNSPacket::expandName(char *name, u_char *msg)
{
  return dn_expand(bom, eom, msg, name, MAXDNAME);
}

/* Parse one question-section entry at the current position of the IP
 * payload into *q: the (possibly compressed) QNAME, then QTYPE and QCLASS.
 * On success the payload position is advanced past the entry.
 *
 * Returns the size of the encoded name (not counting the 4 bytes of
 * type/class). Returns expandName()'s negative result if the name cannot be
 * expanded, and -2 if the IP layer is missing or the entry is truncated.
 */
int DNSPacket::parseQuery(dns_query_t *q, UDPPacket *udpp)
{
  locatePacket(IPPacket, ipp, udpp);
  if (ipp == NULL) return -2;

  const int nameLen = expandName(q->name, (u_char *) ipp->data());
  if (nameLen < 0)
    return nameLen;

  // Fixed part of a question: QTYPE (2 bytes) + QCLASS (2 bytes).
  const size_t needed = (size_t)nameLen + 2 + 2;
  if (ipp->dlength() < needed)
    return -2;

  (ipp->_data).move(nameLen);
  q->type   = (ipp->_data).getShort();
  q->_class = (ipp->_data).getShort();
  return nameLen;
}

/* Parse one resource record at the current position of the IP payload
 * into *r. The record consists of the owner name, TYPE, CLASS, TTL and
 * RDLENGTH, followed by RDATA decoded according to TYPE:
 *  - NS, PTR, CNAME: the target name is expanded;
 *  - MX: the preference is read and the exchange name is expanded;
 *  - any other type: RDATA is copied raw into r->_data, truncated to fit.
 * For MX records r->length is reduced by the 2 preference bytes. On success
 * the payload position is advanced past the whole record.
 *
 * Returns the total number of bytes the record occupies on the wire.
 * Error returns:
 *  - expandName()'s negative result if the owner name cannot be expanded;
 *  - -2 if the IP layer is missing or the fixed fields are truncated;
 *  - -3 if RDATA is truncated or too short for its type.
 */
int DNSPacket::parseRRecord(rrecord_t *r, UDPPacket *udpp)
{
  locatePacket(IPPacket, ipp, udpp);
  if (ipp == NULL) return -2;

  int n = expandName(r->name, (u_char *) ipp->data());
  if (n < 0) return n;
  // Fixed part: TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2).
  if (ipp->dlength() < (size_t)n+10) return -2;
  (ipp->_data).move(n);
  r->type =   (ipp->_data).getShort();
  r->_class = (ipp->_data).getShort();
  r->ttl =    (ipp->_data).getLong();
  r->length = (ipp->_data).getShort();
  if (ipp->dlength() < r->length) return -3;
  n += (2+2+4+2+r->length);

  // expandName() results are checked because dn_expand() may leave the
  // destination unusable on failure. printRR() and serialize() later read
  // these buffers as C strings, so they are set to "" on failure.
  switch(r->type) {
  case T_NS:
    if (expandName(r->t_ns.nameserver, (u_char *) ipp->data()) < 0)
      r->t_ns.nameserver[0] = '\0';
    break;
    
  case T_PTR:
    if (expandName(r->t_ptr.domain, (u_char *) ipp->data()) < 0)
      r->t_ptr.domain[0] = '\0';
    break;

  case T_CNAME:
    if (expandName(r->t_cname.canonical, (u_char *) ipp->data()) < 0)
      r->t_cname.canonical[0] = '\0';
    break;

  case T_MX:
    // The 2-byte PREFERENCE must lie inside RDATA. Otherwise getShort()
    // reads past the record, and r->length -= 2 goes negative or wraps, so
    // the final move() would jump far past the message.
    if (r->length < 2) return -3;
    r->t_mx.preference = (ipp->_data).getShort();
    r->length -= 2;
    // ipp->data() is re-read: getShort() has advanced past PREFERENCE.
    if (expandName(r->t_mx.mxhost, (u_char *) ipp->data()) < 0)
      r->t_mx.mxhost[0] = '\0';
    break;

  default: {
    // RDLENGTH comes straight off the wire (up to 65535). Never copy more
    // than r->_data can hold.
    size_t ncopy = r->length;
    if (ncopy > sizeof(r->_data)) ncopy = sizeof(r->_data);
    memcpy((u_char *)&(r->_data), (u_char *) ipp->data(), ncopy);
    break;
  }
  }

  // Skip the (remaining) RDATA; name expansion above does not advance.
  (ipp->_data).move(r->length);
  return n;
}

/* Print one resource record as "[<TYPE> <data>]".
 * A, NS, PTR, CNAME and MX records show their decoded data. A records show
 * the address only for classes IN and HS; otherwise "[]" is printed. Any
 * other type is printed as "[?<numeric type>]".
 */
void DNSPacket::printRR(ostream *f, const rrecord_t &r)
{
  *f << '[';
  switch(r.type) {
  case T_A:
    if (r._class == C_IN || r._class == C_HS) {
      *f << "A " << intoa(r.t_a.address);
    }
    break;

  case T_NS:
    *f << "NS " << r.t_ns.nameserver;
    break;

  case T_PTR:
    *f << "PTR " << r.t_ptr.domain;
    break;

  case T_CNAME:
    *f << "CNAME " << r.t_cname.canonical;
    break;

  case T_MX:
    // parseRRecord() reads preference with getShort(). TYPE and RDLENGTH
    // come from the same accessor and are used there directly as host
    // values, so preference is already in host order; applying ntohs()
    // again would byte-swap it. The casts keep the unsigned 16-bit display.
    *f << "MX " << r.t_mx.mxhost
       << " (" << (unsigned int)(unsigned short) r.t_mx.preference << ")";
    break;

  default:
    *f <<  '?' << (int) r.type;
    break;
  }
  *f << ']';
}

/* Serialize the decoded DNS fields into str and return the accumulated
 * `count`. The field order below defines the serialized layout and must
 * match DNSPacket::read() exactly. bom/eom are not serialized.
 * NOTE(review): serialVar is a libpandora macro that presumably uses the
 * locals/params `str`, `count` and `maxlen`; `level` is not referenced
 * directly here -- confirm in serialize.h.
 */
size_t DNSPacket::write(char *str, size_t maxlen, int level)
{
  size_t count = 0;

  serialVar(type);
  serialVar(id);
  serialVar(opcode);
  serialVar(aa);
  serialVar(tc);
  serialVar(rd);
  serialVar(ra);
  serialVar(rcode);
  // Record sections; element (de)serialization presumably dispatches to the
  // serialize(dns_query_t*) / serialize(rrecord_t*) overloads in this file.
  serialVar(qd);
  serialVar(an);
  serialVar(ns);
  serialVar(ar);
  serialVar(length);

  return count;
}

/* Inverse of DNSPacket::write(): restore the decoded DNS fields from str,
 * in the same order they are written, and return the accumulated `count`
 * of bytes consumed. bom/eom are not restored.
 * NOTE(review): unserialVar is a libpandora macro that presumably advances
 * `count` through `str`; `level` is not referenced directly here.
 */
size_t DNSPacket::read(const char *str, int level)
{
  size_t count = 0;

  unserialVar(type);
  unserialVar(id);
  unserialVar(opcode);
  unserialVar(aa);
  unserialVar(tc);
  unserialVar(rd);
  unserialVar(ra);
  unserialVar(rcode);
  unserialVar(qd);
  unserialVar(an);
  unserialVar(ns);
  unserialVar(ar);
  unserialVar(length);

  return count;
}

/* Serialize one question entry: name, type and class, in that order.
 * NOTE(review): the name goes through the local char* `tmp` rather than
 * var->name directly, presumably so the char* overload of serialize is
 * selected -- confirm in serialize.h.
 */
void serialize(char *str, size_t &count, const size_t maxlen,  
	       dns_query_t *var)
{
  char *tmp = var->name;
  serialVar(tmp);  
  serialVar(var->type);
  serialVar(var->_class);
}

/* Inverse of serialize(dns_query_t*): read name, type and class in the
 * same order.
 * NOTE(review): the name is read through the local pointer `tmp`, which
 * initially aliases var->name. This fills var->name only if the char*
 * unserialize overload copies into the existing buffer. If it instead
 * allocates and repoints `tmp`, the decoded name is lost -- confirm.
 */
void unserialize(const char *str, size_t &count, dns_query_t *var)
{
  char *tmp = var->name;
  unserialVar(tmp);
  unserialVar(var->type);
  unserialVar(var->_class);
}

/* Serialize one resource record. The owner name, type, class, ttl and
 * length are written first, then the type-specific data for A, NS, PTR,
 * CNAME and MX records. For any other type no RDATA is written (r->_data
 * is dropped).
 * Names are passed through the local char* `tmp`, presumably to select the
 * char* overload of serialize -- confirm in serialize.h.
 */
void serialize(char *str, size_t &count, const size_t maxlen,  
	       rrecord_t *var)
{
  char *tmp = var->name;
  serialVar(tmp);  
  serialVar(var->type);
  serialVar(var->_class);
  serialVar(var->ttl);
  serialVar(var->length);

  // Must stay in sync with unserialize(rrecord_t*).
  switch(var->type) {
  case T_A:	serialVar((var->t_a).address);			break;
  case T_NS:	tmp = (var->t_ns).nameserver; serialVar(tmp);	break;
  case T_PTR:	tmp = (var->t_ptr).domain; serialVar(tmp);	break;
  case T_CNAME:	tmp = (var->t_cname).canonical; serialVar(tmp);	break;
  case T_MX:	tmp = (var->t_mx).mxhost; serialVar(tmp);	
    		serialVar((var->t_mx).preference);		break;
  default:	break;
  }
}

/* Inverse of serialize(rrecord_t*): read the common fields, then the
 * type-specific data for A, NS, PTR, CNAME and MX records. Other types
 * carry no serialized RDATA, so their _data is left untouched.
 * NOTE(review): names are read through the local pointer `tmp`, which
 * aliases the destination buffer. This only fills that buffer if the char*
 * unserialize overload copies in place rather than repointing `tmp` --
 * confirm.
 */
void unserialize(const char *str, size_t &count, rrecord_t *var)
{
  char *tmp = var->name;
  unserialVar(tmp);  
  unserialVar(var->type);
  unserialVar(var->_class);
  unserialVar(var->ttl);
  unserialVar(var->length);

  switch(var->type) {
  case T_A:	unserialVar((var->t_a).address);		break;
  case T_NS:	tmp = (var->t_ns).nameserver; unserialVar(tmp);	break;
  case T_PTR:	tmp = (var->t_ptr).domain; unserialVar(tmp);	break;
  case T_CNAME:	tmp = (var->t_cname).canonical;unserialVar(tmp);break;
  case T_MX:	tmp = (var->t_mx).mxhost; unserialVar(tmp);	
    		unserialVar((var->t_mx).preference);		break;
  default:	break;
  }
}

/* Key algorithm: set k to the DNS message id of the DNSPacket found in pkt.
 * Returns true when a DNSPacket was found (and k was set), false otherwise.
 */
extern_pandora(algo, bool, dnsid, (Packet *pkt, PandoraKey *k))
{
  locatePacket0(DNSPacket, dnsp, pkt);
  if (dnsp != NULL) {
    k->set(dnsp->id);
    return true;
  }
  return false;
}

/* Key algorithm: set k to (IP source address, resolved address).
 * The resolved address is chosen as follows:
 *  - the first answer, when it is an A record;
 *  - the second answer, when the first is a CNAME directly followed by an
 *    A record;
 *  - otherwise, the IP destination address.
 * Returns false when pkt has no DNS or IP layer, true otherwise.
 */
extern_pandora(algo, bool, dnsreq, (Packet *pkt, PandoraKey *k))
{
  locatePacket0(DNSPacket, dnsp, pkt);
  if (dnsp == NULL) return false;
  locatePacket(IPPacket, ipp, dnsp);
  if (ipp == NULL) return false;

  // The answer section may be empty (e.g. for requests), so both an[0] and
  // an[1] are bounds-checked before use. The CNAME shortcut is only taken
  // when an[1] really is an A record; otherwise t_a would be read from
  // another record type's data.
  if (dnsp->an.size() > 0 && (dnsp->an[0]).type == T_A) {
    k->set((ipp->src).s_addr, (dnsp->an[0]).t_a.address);
  } else if (dnsp->an.size() > 1
             && (dnsp->an[0]).type == T_CNAME
             && (dnsp->an[1]).type == T_A) {
    k->set((ipp->src).s_addr, (dnsp->an[1]).t_a.address);
  } else {
    k->set((ipp->src).s_addr, (ipp->dst).s_addr);
  }

  return true;
}
