/* dns.c
 * Helper functions for the types defined in dns.h:
 * create/destroy, clone, compare, convert, etc.
 *
 * (c) NLnet Labs 2004
 * See the file LICENSE for the license
 *
 */

#include <stdlib.h>
#include <string.h>
#include <assert.h>

#include "common.h"

/**
 * Allocate a new rdata holding a private copy of the given bytes.
 *
 * @param str   bytes to copy into the new rdata
 * @param size  number of bytes in str; must not exceed MAX_RDLENGTH
 * @return the new rdata (release with rdata_destroy), or NULL after a
 *         warning when size exceeds MAX_RDLENGTH
 */
struct t_rdata *
rdata_create(const uint8_t *str, size_t size)
{
	struct t_rdata *rd;

	if (size <= MAX_RDLENGTH) {
		rd = xmalloc(sizeof *rd);
		rd->data = xmalloc(size);
		rd->length = (uint16_t) size;
		memcpy(rd->data, str, size);	/* copy the caller's bytes */
		return rd;
	}

	warning("size IS WAY TO BIG");
	return NULL;
}

/**
 * Duplicate an rdata, including a fresh copy of its data bytes.
 *
 * @param rdata  the rdata to copy; must not be NULL (asserted)
 * @return a newly allocated rdata; release with rdata_destroy
 */
struct t_rdata *
rdata_clone(struct t_rdata *rdata)
{
	struct t_rdata *copy;

	assert(rdata != NULL);

	copy = xmalloc(sizeof *copy);
	copy->length = rdata->length;
	copy->data = xmalloc(copy->length);
	memcpy(copy->data, rdata->data, copy->length);
	return copy;
}

/**
 * Release an rdata and the data buffer it owns.
 *
 * A NULL argument is reported through error() and the function then
 * returns without touching memory. This matches rr_destroy and
 * dpacket_destroy; falling through would dereference the NULL pointer.
 *
 * @param rdata  the rdata to free (NULL is reported, see above)
 */
void
rdata_destroy(struct t_rdata *rdata)
{
	if (!rdata) {
		error("rdata_destroy called with NULL value");
		return;
	}

	if (rdata->data)
		xfree(rdata->data);
	xfree(rdata);
}

/**
 * Return a NUL-terminated copy of the raw rdata bytes.
 *
 * Embedded NUL bytes in the rdata are copied as-is, so string functions
 * will only see the part before the first one.
 *
 * @param r  rdata to convert; NULL yields NULL
 * @return a newly allocated string the caller must xfree(), or NULL
 */
char *
rdata2str(struct t_rdata *r)
{
	char *out = NULL;

	if (r != NULL) {
		out = xmalloc(r->length + 1);
		memcpy(out, r->data, r->length);
		out[r->length] = '\0';
	}
	return out;
}

/**
 * Parse the rdata text as a decimal number and return it as a uint8_t.
 *
 * strtoul() is used instead of atoi(), which has undefined behavior when
 * the value does not fit in an int. As before, the result is truncated to
 * 8 bits, and text that does not start with a number yields 0.
 *
 * @param r  rdata holding the textual number; NULL yields 0
 * @return the parsed value truncated to 8 bits
 */
uint8_t
rdata2uint8(struct t_rdata *r)
{
	char *str;
	unsigned long v;

	if (!r)
		return 0;

	str = rdata2str(r);
	v = strtoul(str, NULL, 10);
	xfree(str);
	return (uint8_t) v;
}

/**
 * Parse the rdata text as a decimal number and return it as a uint16_t.
 *
 * strtoul() is used instead of atoi(), which has undefined behavior when
 * the value does not fit in an int. As before, the result is truncated to
 * 16 bits, and text that does not start with a number yields 0.
 *
 * @param r  rdata holding the textual number; NULL yields 0
 * @return the parsed value truncated to 16 bits
 */
uint16_t
rdata2uint16(struct t_rdata *r)
{
	char *str;
	unsigned long v;

	if (!r)
		return 0;

	str = rdata2str(r);
	v = strtoul(str, NULL, 10);
	xfree(str);
	return (uint16_t) v;
}

/**
 * Parse the rdata text as a decimal number and return it as a uint32_t.
 *
 * strtoul() is used because atoi() returns an int. Values above INT_MAX,
 * which are valid uint32_t values, are undefined behavior for atoi() (and
 * saturate on platforms where long is 32 bits). unsigned long always
 * holds at least 32 bits. Text that does not start with a number yields 0.
 *
 * @param r  rdata holding the textual number; NULL yields 0
 * @return the parsed value truncated to 32 bits
 */
uint32_t
rdata2uint32(struct t_rdata *r)
{
	char *str;
	unsigned long v;

	if (!r)
		return 0;

	str = rdata2str(r);
	v = strtoul(str, NULL, 10);
	xfree(str);
	return (uint32_t) v;
}

/**
 * Parse the rdata text as a decimal number and return it as a uint64_t.
 * (There is no uint48_t, so 48-bit fields are returned through this one.)
 *
 * strtoull() is used because atoi() returns an int and therefore can
 * never produce a value wider than int. unsigned long long always holds
 * at least 64 bits. Text that does not start with a number yields 0.
 *
 * @param r  rdata holding the textual number; NULL yields 0
 * @return the parsed value
 */
uint64_t
rdata2uint64(struct t_rdata *r)
{
	char *str;
	unsigned long long v;

	if (!r)
		return 0;

	str = rdata2str(r);
	v = strtoull(str, NULL, 10);
	xfree(str);
	return (uint64_t) v;
}

/**
 * Allocate a new RR with the given owner name, type, TTL and section.
 * The class is set to CLASS_IN and the RR starts without rdata.
 *
 * @param name  owner name; a private copy is stored (rdata_clone), so the
 *              caller keeps ownership of name. Must not be NULL.
 * @param type  RR type
 * @param ttl   TTL. NOTE(review): this parameter is only 16 bits wide,
 *              while rrset_set_ttl() uses a uint32_t, so larger TTLs are
 *              truncated here. Widening it requires changing the header
 *              prototype as well.
 * @param sec   section the RR belongs to
 * @return the new RR; release with rr_destroy
 */
struct t_rr *
rr_create(struct t_rdata *name, uint16_t type, uint16_t ttl, t_section sec)
{
	struct t_rr *newrr = xmalloc(sizeof *newrr);

	newrr->name = rdata_clone(name);
	newrr->type = type;
	newrr->class = CLASS_IN;
	newrr->section = sec;
	newrr->ttl = ttl;
	newrr->rdcount = 0;
	/* rdata is an array of rdata *pointers* (rr_add_rdata stores pointers
	 * in it), so size the initial slot by the element type. Allocating
	 * one slot up front gives rr_add_rdata a valid pointer to xrealloc. */
	newrr->rdata = xmalloc(sizeof *newrr->rdata);
	newrr->next = NULL;
	return newrr;
}

/**
 * Count the RRs in a chain by following the next pointers.
 *
 * @param rr  first RR of the chain; NULL counts as an empty chain
 * @return number of RRs reachable from rr, including rr itself
 */
size_t
rr_size(struct t_rr *rr)
{
	size_t count = 0;
	struct t_rr *cur;

	for (cur = rr; cur != NULL; cur = cur->next)
		count++;
	return count;
}

/**
 * Order two RRs by their wire format (used by rrset_sort).
 *
 * Both RRs are serialized with rr2wire() using the same section argument
 * (SEC_ALL, as rr_cmp does). Using different arguments for the two sides
 * would make the ordering asymmetric. A shorter encoding sorts first.
 * Encodings of equal length are compared byte by byte, and the first
 * differing byte decides.
 *
 * @param rr1  first RR; must not be NULL
 * @param rr2  second RR; must not be NULL
 * @return -1 if rr1 sorts before rr2, 1 if after, 0 if equal
 */
int
rr_compare(struct t_rr *rr1, struct t_rr *rr2)
{
	uint8_t *bytes1, *bytes2;
	size_t length1, length2;
	int result;

	assert(rr1 != NULL);
	assert(rr2 != NULL);

	bytes1 = xmalloc(MAX_PACKET);
	bytes2 = xmalloc(MAX_PACKET);

	length1 = rr2wire(rr1, bytes1, 0, MAX_PACKET, SEC_ALL, NO_FOLLOW);
	length2 = rr2wire(rr2, bytes2, 0, MAX_PACKET, SEC_ALL, NO_FOLLOW);

	if (length1 != length2) {
		result = (length1 < length2) ? -1 : 1;
	} else {
		/* memcmp compares as unsigned char, i.e. like uint8_t */
		int diff = memcmp(bytes1, bytes2, length1);
		result = (diff > 0) - (diff < 0);
	}

	xfree(bytes1);
	xfree(bytes2);
	return result;
}

/**
 * Append rr2 to the end of the chain that starts at rr1.
 *
 * - if rr2 == NULL: nothing is done and rr1 is returned (also when rr1 is
 *   NULL);
 * - if rr1 == NULL: a copy of rr2 is returned (rr_clone with NO_FOLLOW,
 *   i.e. only the first RR of rr2), and rr2 itself is not taken over;
 * - otherwise: rr2, together with anything chained behind it, is linked
 *   in directly, so the chain at rr1 now owns it.
 *
 * rr2 must not already be part of rr1's chain (asserted), because linking
 * it again would make the chain cyclic.
 *
 * @return the head of the resulting chain
 */
struct t_rr *
rr_add_rr(struct t_rr *rr1, struct t_rr *rr2)
{
	struct t_rr *tail;

	if (!rr2)
		return rr1;

	if (!rr1)
		return rr_clone(rr2, NO_FOLLOW);

	/* walk to the last RR, checking every node (including the last) */
	for (tail = rr1; ; tail = tail->next) {
		assert(tail != rr2);
		if (tail->next == NULL)
			break;
	}
	tail->next = rr2;
	return rr1;
}

/**
 * Append an rdata to the RR's rdata array.
 *
 * The RR takes ownership of rdata: rr_destroy frees every entry of the
 * array.
 *
 * @param rdata  rdata to append; must not be NULL
 * @param rr     RR to extend; must not be NULL
 * @return RET_SUCCESS, or RET_FAIL if rdata or rr is NULL
 */
int
rr_add_rdata(struct t_rdata *rdata, struct t_rr *rr)
{
	if (!rdata || !rr)
		return RET_FAIL;

	/* the array holds rdata pointers, so size it by the element type */
	rr->rdata = xrealloc(rr->rdata, (rr->rdcount + 1) * sizeof *rr->rdata);
	rr->rdata[rr->rdcount] = rdata;
	rr->rdcount++;
	return RET_SUCCESS;
}

/**
 * Free an RR: its owner name, every rdata it holds, the rdata array and
 * the RR itself.
 *
 * @param rr1     RR to free; NULL is reported through error() and ignored
 * @param follow  when non-zero, the rest of the chain (rr1->next onwards)
 *                is freed as well, before rr1 itself
 */
void
rr_destroy(struct t_rr *rr1, unsigned int follow)
{
	uint16_t idx;

	if (rr1 == NULL) {
		error("rr_destroy called with NULL value");
		return;
	}

	/* the tail is released before this node */
	if (follow && rr1->next != NULL)
		rr_destroy(rr1->next, follow);

	rdata_destroy(rr1->name);
	for (idx = 0; idx < rr1->rdcount; idx++)
		rdata_destroy(rr1->rdata[idx]);

	xfree(rr1->rdata);
	xfree(rr1);
}

/**
 * Search an RR chain for a specific RR.
 *
 * A chain member matches when it is the very same object as rr_what, or
 * when rr_cmp() reports the two as equal (rr_cmp ignores the TTL).
 *
 * @param rr_in    first RR of the chain to search; must not be NULL
 * @param rr_what  RR to look for; must not be NULL
 * @return the first matching RR in the chain, or NULL if none matches
 */
struct t_rr *
rr_get_rr(struct t_rr *rr_in, struct t_rr *rr_what)
{
	struct t_rr *tmprr;

	assert(rr_in != NULL);
	assert(rr_what != NULL);

	/* the head is covered by the first iteration; a separate pre-check
	 * would only run the (expensive) rr_cmp on it a second time */
	for (tmprr = rr_in; tmprr != NULL; tmprr = tmprr->next) {
		if (tmprr == rr_what || rr_cmp(tmprr, rr_what) == 0)
			return tmprr;
	}
	return NULL;
}

/**
 * Replace the owner name of an RR.
 *
 * The name pointer is stored as-is (no copy), so the RR takes ownership
 * of it and rr_destroy will free it.
 * NOTE(review): the previous name is not freed here. Since rr_create gives
 * every RR its own cloned name, it is presumably leaked unless the caller
 * frees it; confirm against callers.
 *
 * @param name  new owner name; must not be NULL
 * @param rr    RR to update; must not be NULL
 * @return RET_SUC, or RET_FAIL if rr or name is NULL or name->length
 *         exceeds MAX_RDLENGTH - 1
 */
int
rr_set_name(struct t_rdata *name, struct t_rr *rr)
{
	if (rr == NULL || name == NULL || name->length > MAX_RDLENGTH - 1)
		return RET_FAIL;

	rr->name = name;
	return RET_SUC;
}

/**
 * Set the class of an RR.
 * XXX MG: should be handled in create.
 *
 * @param c   the new class value
 * @param rr  RR to update; must not be NULL
 * @return RET_SUC, or RET_FAIL if rr is NULL
 */
int
rr_set_class(uint16_t c, struct t_rr *rr)
{
	if (rr == NULL)
		return RET_FAIL;

	rr->class = c;
	return RET_SUC;
}

/**
 * Allocate an empty packet.
 *
 * The new packet has every header flag cleared, all four section counters
 * at zero, no RRs, no server address, a UDP packet size of 512, the
 * xrcode/version/dnssec_ok/z option fields set to zero, querytime 0, and
 * date set to time(NULL).
 *
 * The id is random() % MAX_ID, so the random generator must be seeded
 * beforehand (presumably by the caller).
 *
 * @return the new packet; release with dpacket_destroy
 */
struct t_dpacket *
dpacket_create(void)
{
	struct t_dpacket *pkt = xmalloc(sizeof *pkt);

	pkt->id = (uint16_t)(random() % MAX_ID);

	/* header flags: everything off */
	pkt->flags.qr     = 0;
	pkt->flags.opcode = 0;
	pkt->flags.aa     = 0;
	pkt->flags.tc     = 0;
	pkt->flags.rd     = 0;
	pkt->flags.ra     = 0;
	pkt->flags.z      = 0;
	pkt->flags.ad     = 0;
	pkt->flags.cd     = 0;
	pkt->flags.rcode  = 0;

	/* no RRs in any section yet */
	pkt->count[SEC_QUESTION] = 0;
	pkt->count[SEC_ANSWER]   = 0;
	pkt->count[SEC_AUTH]     = 0;
	pkt->count[SEC_ADD]      = 0;
	pkt->rrs = NULL;

	/* option defaults */
	pkt->udppacketsize = 512;
	pkt->opt.xrcode    = 0;
	pkt->opt.version   = 0;
	pkt->opt.dnssec_ok = 0;
	pkt->opt.z         = 0;

	/* bookkeeping */
	pkt->serverip  = NULL;
	pkt->querytime = 0;
	pkt->date      = time(NULL);

	return pkt;
}

/**
 * Free a packet together with its server address and its whole RR chain
 * (rr_destroy with FOLLOW).
 *
 * @param packet  packet to free; NULL is reported through error() and
 *                ignored
 */
void
dpacket_destroy(struct t_dpacket *packet)
{
	if (packet == NULL) {
		error("dpacket_destroy called with NULL value");
		return;
	}

	if (packet->serverip != NULL)
		xfree(packet->serverip);

	if (packet->rrs != NULL)
		rr_destroy(packet->rrs, FOLLOW);

	xfree(packet);
}

/**
 * Append an RR to the packet's RR chain and increment the counter of
 * section sec.
 *
 * The packet takes ownership of rr: dpacket_destroy frees the whole chain.
 *
 * @param rr   RR to add; must not be NULL, and rr->section must equal sec
 *             (an OPT RR therefore goes into SEC_ADD)
 * @param sec  section whose counter is incremented
 * @param pkt  destination packet; must not be NULL
 * @return RET_SUCCESS
 *
 * NOTE(review): count[sec] grows by exactly one even if rr has further
 * RRs chained behind it. Confirm that callers always pass a single RR.
 */
int
dpacket_add_rr(struct t_rr *rr, t_section sec, struct t_dpacket *pkt)
{
	struct t_rr *tail;

	assert(rr != NULL);
	assert(pkt != NULL);
	assert(rr->section == sec);

	if (pkt->rrs == NULL) {
		pkt->rrs = rr;
	} else {
		/* rr must differ from every node, including the last one:
		 * adding the same RR twice would make the chain cyclic */
		for (tail = pkt->rrs; ; tail = tail->next) {
			assert(tail != rr);
			if (tail->next == NULL)
				break;
		}
		tail->next = rr;
	}
	pkt->count[sec]++;
	return RET_SUCCESS;
}

/**
 * Extract all RRs of a given type (and optionally owner name) from a
 * packet, i.e. build an RRset.
 *
 * A warning is emitted for each match found in the question section.
 *
 * @param name  owner name to match (via rdata_cmp); NULL matches any name
 * @param type  RR type to match
 * @param p     packet to search; must not be NULL
 * @param sec   section to search, or SEC_ALL to look through every section
 * @return a newly allocated chain of copies (rr_clone, NO_FOLLOW) of the
 *         matching RRs in packet order, or NULL if nothing matched
 */
struct t_rr *
dpacket_get_rrset(struct t_rdata *name, uint16_t type, struct t_dpacket *p, t_section sec)
{
	struct t_rr *head = NULL;
	struct t_rr **tail = &head;	/* where the next copy gets linked */
	struct t_rr *r;

	/* checked before p->rrs is read, not after */
	assert(p != NULL);

	for (r = p->rrs; r != NULL; r = r->next) {
		if (r->type != type)
			continue;
		/* rdata_cmp returns 1 when the names are equal */
		if (name && !rdata_cmp(name, r->name))
			continue;
		if (r->section != sec && sec != SEC_ALL)
			continue;

		if (r->section == SEC_QUESTION)
			warning("Looking in question section\n");

		*tail = rr_clone(r, NO_FOLLOW);
		tail = &(*tail)->next;
	}
	return head;
}

/**
 * Set the same TTL on every RR in a chain.
 *
 * @param rrset  first RR of the chain; NULL is a no-op
 * @param int32  the TTL to store in each RR
 */
void
rrset_set_ttl(struct t_rr *rrset, uint32_t int32)
{
	struct t_rr *cur;

	for (cur = rrset; cur != NULL; cur = cur->next)
		cur->ttl = int32;
}

/**
 * Sort an RR chain in place by wire format, using rr_compare for ordering.
 *
 * This is an insertion sort. The RRs are taken from the back of the
 * original chain, and each one is inserted into the sorted chain after
 * every RR it does not compare smaller than. Walking backwards means a
 * strictly increasing input needs only one comparison per RR, which
 * matters because rr_compare serializes both RRs on every call. RRs that
 * compare equal end up in the reverse of their original relative order.
 *
 * @param rrset  in/out: address of the chain head. On return it points at
 *               the head of the sorted chain. rrset must not be NULL;
 *               *rrset may be NULL (empty chain).
 */
void
rrset_sort(struct t_rr **rrset)
{
	struct t_rr **nodes;
	struct t_rr *sorted = NULL;
	struct t_rr *cur;
	size_t size, i;

	assert(rrset != NULL);

	size = rr_size(*rrset);
	/* 0 or 1 RRs are already sorted. Returning early also avoids
	 * declaring a zero-length array, which is undefined behavior. */
	if (size < 2)
		return;

	/* heap instead of a VLA, so a long chain cannot exhaust the stack */
	nodes = xmalloc(size * sizeof *nodes);
	for (i = 0, cur = *rrset; cur != NULL; cur = cur->next)
		nodes[i++] = cur;

	for (i = size; i-- > 0; ) {
		struct t_rr **link = &sorted;

		cur = nodes[i];
		/* skip past every RR that cur does not sort before */
		while (*link != NULL && rr_compare(cur, *link) >= 0)
			link = &(*link)->next;
		cur->next = *link;
		*link = cur;
	}

	xfree(nodes);
	*rrset = sorted;
}

/**
 * Test two rdatas for equality: they are equal when their lengths match
 * and all data bytes are identical. A length mismatch is logged through
 * vverbose().
 *
 * @param r1  first rdata; must not be NULL
 * @param r2  second rdata; must not be NULL
 * @return 1 (true) if equal, 0 (false) otherwise
 */
uint8_t
rdata_cmp(struct t_rdata *r1, struct t_rdata *r2)
{
	char *s1, *s2;

	if (r1->length == r2->length) {
		/* nothing to compare; also keeps memcmp away from
		 * possibly-NULL buffers of empty rdatas */
		if (r1->length == 0)
			return 1;
		return memcmp(r1->data, r2->data, r1->length) == 0 ? 1 : 0;
	}

	s1 = rdata2str(r1);
	s2 = rdata2str(r2);
	vverbose("length of rdatas %s and %s not equal: %d and %d\n", s1, s2, r1->length, r2->length);
	xfree(s1);
	xfree(s2);
	return 0;
}

/**
 * Make a deep copy of an RR. The owner name, type, TTL, section, class and
 * all rdata are duplicated.
 *
 * @param in      RR to copy; NULL yields NULL
 * @param follow  NO_FOLLOW copies only `in` itself (the copy's next is
 *                NULL). Any other value (e.g. FOLLOW) also copies the rest
 *                of the chain behind `in`.
 * @return the copy (release with rr_destroy), or NULL if in is NULL
 */
struct t_rr *
rr_clone(struct t_rr *in, unsigned int follow)
{
	struct t_rr *out;
	uint16_t i;

	if (!in)
		return NULL;

	out = rr_create(in->name, in->type, in->ttl, in->section);
	/* rr_create takes its TTL as uint16_t; copy the field directly so a
	 * larger TTL (rrset_set_ttl stores a uint32_t) is not truncated */
	out->ttl = in->ttl;
	out->class = in->class;
	for (i = 0; i < in->rdcount; i++)
		rr_add_rdata(rdata_clone(in->rdata[i]), out);

	/* rr_create leaves next NULL; copy the chain only when asked to */
	if (follow != NO_FOLLOW)
		out->next = rr_clone(in->next, follow);

	return out;
}

/**
 * Check whether a packet needs an OPT pseudo-RR.
 *
 * An OPT RR is needed when GET_UDPSIZE(p) exceeds 512, or when
 * GET_DNSSEC(p) equals 1 (presumably the packet's dnssec_ok option).
 *
 * @param p  the packet to inspect
 * @return 1 if an OPT RR is needed, 0 otherwise
 */
uint8_t
packet_is_opt(struct t_dpacket *p)
{
	int needs_opt = GET_UDPSIZE(p) > 512 || GET_DNSSEC(p) == 1;

	return (uint8_t) needs_opt;
}

/**
 * Counts the labels in a rr->name
 */
uint8_t 
label_cnt(struct t_rr *rr)
{
	size_t i;
	uint8_t cnt = 0;
	for (i = 0; i < (rr->name->length); ++i) {
		if (rr->name->data[i] == '.') cnt++;
	}
	return cnt;
}

/**
 * Return a new RR whose owner name is rr's name with the leftmost
 * `labels` dot-separated labels removed.
 *
 * With labels == 0 the result is a full copy of rr (rr_clone). Otherwise
 * the result is built with rr_create: it carries rr's type, TTL and
 * section, but class CLASS_IN and no rdata. The new name may be empty,
 * for example when the name has fewer than `labels` dots.
 *
 * @param rr      RR whose name is chopped; rr and rr->name must not be NULL
 * @param labels  number of leading labels to remove
 * @return a newly allocated RR; release with rr_destroy
 */
struct t_rr *
chop_labels_left(struct t_rr *rr, uint8_t labels)
{
	uint8_t lab_cnt = 0;
	size_t i;
	struct t_rdata *name;
	struct t_rr *chopped;

	if (labels == 0)
		return rr_clone(rr, NO_FOLLOW);

	/* find the first non-dot byte after the labels-th dot */
	for (i = 0; i < rr->name->length; ++i) {
		if (rr->name->data[i] == '.') {
			lab_cnt++;
			continue;
		}
		if (lab_cnt >= labels)
			break;
	}

	/* build the shortened name in a temporary copy of the old one */
	name = rdata_clone(rr->name);
	name->length = rr->name->length - i;
	memcpy(name->data, rr->name->data + i, name->length);

	/* rr_create stores its own clone of the name, so the temporary
	 * must be released here or it leaks */
	chopped = rr_create(name, rr->type, rr->ttl, rr->section);
	rdata_destroy(name);
	return chopped;
}

/**
 * Compare two RRs by their wire format, ignoring the TTL.
 *
 * Both RRs are copied, and the copies get TTL DEF_TTL so that TTL
 * differences do not count. The copies are then serialized with rr2wire().
 * A shorter encoding sorts first. Encodings of equal length are compared
 * lexicographically: the FIRST differing byte decides. Letting a later
 * byte override it would not give a consistent ordering.
 *
 * @param aa  first RR; must not be NULL
 * @param bb  second RR; must not be NULL
 * @return -1 if aa sorts before bb, 1 if after, 0 if they are equal
 */
int
rr_cmp(struct t_rr *aa, struct t_rr *bb)
{
	struct t_rr *a, *b;
	uint8_t *wirea, *wireb;
	size_t sizea, sizeb;
	int result;

	/* rr_clone(NULL) returns NULL, which would crash on a->ttl below */
	assert(aa != NULL);
	assert(bb != NULL);

	a = rr_clone(aa, NO_FOLLOW);
	b = rr_clone(bb, NO_FOLLOW);
	a->ttl = DEF_TTL;
	b->ttl = DEF_TTL;

	wirea = xmalloc(MAX_PACKET);
	wireb = xmalloc(MAX_PACKET);

	sizea = rr2wire(a, wirea, 0, MAX_PACKET, SEC_ALL, NO_FOLLOW);
	sizeb = rr2wire(b, wireb, 0, MAX_PACKET, SEC_ALL, NO_FOLLOW);

	if (sizea != sizeb) {
		result = (sizea < sizeb) ? -1 : 1;
	} else {
		/* memcmp stops at the first difference and compares bytes
		 * as unsigned char, i.e. like uint8_t */
		int diff = memcmp(wirea, wireb, sizea);
		result = (diff > 0) - (diff < 0);
	}

	rr_destroy(a, NO_FOLLOW);
	rr_destroy(b, NO_FOLLOW);
	xfree(wirea);
	xfree(wireb);
	return result;
}
