/* wire.c
 *
 * Converts from and to wireformat
 *
 * (c) NLnet Labs, 2004
 *
 * See the file LICENSE for the license
 *
 */

#include "common.h"

/**
 * Converts the given domain name to wire format and places it in the buffer at
 * the specified offset returns the total length of the wire data
 */
size_t
name2wire(struct t_rdata *name, uint8_t *buf, unsigned long offset, const size_t buf_len)
{

	uint16_t lastpos = 0;
	unsigned long cur_offset = offset;
	char *n;
	
	uint16_t si, ti;
	uint16_t i;

	for (i=0; i < name->length; i++) {
		if (name->data[i] == '.' && (name->data[i-1] != '\\')) {
			if (i - lastpos == 0) {
				n = rdata2str(name);
				if (strncmp(".", n, 2) == 0) {
					// n/m, root is exception
					xfree(n);
				} else {
					error("%s is not a valid name (empty label)", n);
					xfree(n);
					exit(EXIT_FAILURE);
				}
			}
			ti = 0;

			if (0 == i - lastpos) {
				/* root label, handle seperately */
				if (cur_offset + 1 <= buf_len) {
					buf[cur_offset + 1] = name->data[lastpos];
					ti++;
				} else {
					fprintf(stderr, "Buffer too small in name2wire\n");
					exit(1);
				}
				break;
			}
			
			for (si = 0; si < i - lastpos; si++) {
				if (name->data[lastpos + si] == '\\') {
					if (
						isdigit((int) name->data[lastpos + si +1]) &&
						isdigit((int) name->data[lastpos+ si + 2]) &&
						isdigit((int) name->data[lastpos+ si + 3])
					   ) {
					   	/* it's an octal repr (fi \255) */
					   	n = xmalloc(4);
					   	memcpy(n, &name->data[lastpos+ si + 1], 3);
					   	n[3] = '\0';
					   	if (cur_offset + ti + 1 <= buf_len) {
					   		buf[cur_offset + ti +1] = (uint8_t) atoi(n);
						} else {
							fprintf(stderr, "Buffer too small in name2wire\n");
							exit(1);
						}
					   	xfree(n);
					   	si += 4;
					} else {
						/* it's 'just' an escaped char */
						si++;
					   	if (cur_offset + ti + 1 <= buf_len) {
							buf[cur_offset+ ti + 1] = name->data[lastpos + si];
						} else {
							fprintf(stderr, "Buffer too small in name2wire\n");
							exit(1);
						}
					}
				} else {
					if (cur_offset + ti + 1 <= buf_len) {
						buf[cur_offset + ti + 1] = name->data[lastpos + si];
					} else {
						fprintf(stderr, "Buffer too small in name2wire\n");
						exit(1);
					}
				}
				ti++;
			}
			//memcpy(&buf[cur_offset+1], &name->data[lastpos], buf[cur_offset]);
			if (cur_offset <= buf_len) {
				buf[cur_offset] = ti;
			} else {
				fprintf(stderr, "Buffer too small in name2wire\n");
				exit(1);
			}
			cur_offset += buf[cur_offset] + 1;
			lastpos = i + 1;
		}
	}
	if (cur_offset <= buf_len) {
		buf[cur_offset] = 0;
	} else {
		fprintf(stderr, "Buffer too small in name2wire\n");
		exit(1);
	}
	cur_offset++;
	
	return cur_offset - offset;
}

/**
 * Writes rr and every rr linked after it (via next) to buf at offset in
 * wire format, whatever section each belongs to, by delegating to rr2wire()
 * with SEC_ALL and FOLLOW. Returns the number of bytes written.
 */
size_t
rrset2wire(struct t_rr *rr, uint8_t *buf, unsigned long offset, const size_t buf_len)
{
	return rr2wire(rr, buf, offset, buf_len, SEC_ALL, FOLLOW);
}

/**
 * Converts rr to wire format at buf[offset] and returns the number of bytes
 * written. An rr is written only if its section equals section (every rr is
 * written when section is SEC_ALL). With follow == FOLLOW the rest of the
 * rr->next chain is processed as well; otherwise only rr itself is
 * considered.
 *
 * For SEC_QUESTION only owner name, type and class are written; for other
 * sections the ttl and rdata (via rdata2wire()) follow.
 * Exits with a message on stderr if the type/class/ttl fields do not fit in
 * buf_len bytes.
 *
 * NOTE(review): the ttl/rdata decision looks at the section argument, not
 * rr->section, so with SEC_ALL a question rr would also get ttl and rdata.
 *
 * If you get a segfault on rr->section here, you have most probably
 * forgotten to call create_packet first.
 */
size_t
rr2wire(struct t_rr *rr, uint8_t *buf, unsigned long offset, const size_t buf_len, t_section section, unsigned int follow)
{
	size_t length = 0;
	size_t fixed;
	uint16_t int16;
	uint32_t int32;

	/* iterate instead of recursing, so long rr chains cannot exhaust the stack */
	for (; rr != NULL; rr = rr->next) {
		if (rr->section == section || section == SEC_ALL) {
			/* owner name */
			length += name2wire(rr->name, buf, offset + length, buf_len);

			/* type and class always follow; ttl only outside the question section */
			fixed = (section != SEC_QUESTION) ? 8 : 4;
			if (offset + length + fixed > buf_len) {
				fprintf(stderr, "Buffer too small in rr2wire\n");
				exit(1);
			}

			/* type */
			int16 = htons(rr->type);
			memcpy(&buf[offset + length], &int16, 2);
			length += 2;

			/* class */
			int16 = htons(rr->class);
			memcpy(&buf[offset + length], &int16, 2);
			length += 2;

			/* question rrs carry no ttl or rdata on the wire */
			if (section != SEC_QUESTION) {
				int32 = htonl(rr->ttl);
				memcpy(&buf[offset + length], &int32, 4);
				length += 4;

				length += rdata2wire(rr, buf, offset + length, buf_len);
			}
		}
		if (follow != FOLLOW) {
			break;
		}
	}
	return length;
}

/**
 * Converts packet into wire format in buf (which must hold buf_len bytes)
 * and returns the number of bytes written.
 *
 * Writes the 12-byte header, then the rrs of packet->rrs section by section
 * (question, answer, authority, additional). If the packet asks for a UDP
 * size above 512 or has the DO bit set, an OPT rr is appended and the
 * ARCOUNT in the header is incremented to include it.
 * Exits with a message on stderr if the header or the OPT rr does not fit.
 *
 * If you get a segfault on rr->section in rr2wire, you have most probably
 * forgotten to call create_packet first.
 */
size_t
packet2wire(struct t_dpacket *packet, uint8_t *buf, const size_t buf_len)
{
	size_t length = 0;
	int totalcount = 0;
	uint16_t int16;

	if (buf_len < 12) {
		fprintf(stderr, "Buffer too small in packet2wire\n");
		exit(1);
	}

	/* id */
	int16 = htons(packet->id);
	memcpy(&buf[0], &int16, 2);
	/* QR */
	set_bit(&buf[2], 7, packet->flags.qr);
	/* opcode */
	/* NOTE(review): wire2packet reads opcode and rcode back with get_bits()
	 * and no shift; confirm the >> 1 here and the >> 4 for rcode below
	 * match set_bits()'s convention. */
	set_bits(&buf[2], packet->flags.opcode >> 1, 6, 3);
	/* aa */
	set_bit(&buf[2], 2, packet->flags.aa);
	/* tc */
	set_bit(&buf[2], 1, packet->flags.tc);
	/* rd */
	set_bit(&buf[2], 0, packet->flags.rd);
	/* ra */
	set_bit(&buf[3], 7, packet->flags.ra);
	/* z */
	set_bit(&buf[3], 6, packet->flags.z);
	/* ad */
	set_bit(&buf[3], 5, packet->flags.ad);
	/* cd */
	set_bit(&buf[3], 4, packet->flags.cd);
	/* rcode */
	set_bits(&buf[3], packet->flags.rcode >> 4, 0, 3);
	/* qdcount */
	int16 = htons(packet->count[SEC_QUESTION]);
	memcpy(&buf[4], &int16, 2);
	/* ancount */
	int16 = htons(packet->count[SEC_ANSWER]);
	memcpy(&buf[6], &int16, 2);
	/* nscount */
	int16 = htons(packet->count[SEC_AUTH]);
	memcpy(&buf[8], &int16, 2);
	/* arcount */
	int16 = htons(packet->count[SEC_ADD]);
	memcpy(&buf[10], &int16, 2);

	length = 12;

	totalcount += packet->count[SEC_QUESTION];
	totalcount += packet->count[SEC_ANSWER];
	totalcount += packet->count[SEC_AUTH];
	totalcount += packet->count[SEC_ADD];

	if (totalcount > 0) {
		/* query */
		length += rr2wire(packet->rrs, buf, length, buf_len, SEC_QUESTION, FOLLOW);
		/* an records */
		length += rr2wire(packet->rrs, buf, length, buf_len, SEC_ANSWER, FOLLOW);
		/* ns records */
		length += rr2wire(packet->rrs, buf, length, buf_len, SEC_AUTH, FOLLOW);
		/* ad records */
		length += rr2wire(packet->rrs, buf, length, buf_len, SEC_ADD, FOLLOW);
	}

	/* The OPT rr is written byte by byte instead of through a struct t_rr,
	 * to keep the rr functions simple. */
	/* TODO: TCP */
	if (packet->udppacketsize > 512 || packet->opt.dnssec_ok != 0) {
		if (length + 11 > buf_len) {
			fprintf(stderr, "Buffer too small in packet2wire\n");
			exit(1);
		}
		/* owner: the root name */
		buf[length] = 0;
		/* type OPT (41) */
		buf[length+1] = 0;
		buf[length+2] = 41;
		/* class carries the requested UDP packet size */
		int16 = htons(packet->udppacketsize);
		memcpy(&buf[length+3], &int16, 2);
		/* xrcode */
		buf[length+5] = packet->opt.xrcode;
		/* version */
		buf[length+6] = packet->opt.version;
		/* z, in network byte order like every other 16-bit field */
		int16 = htons(packet->opt.z);
		memcpy(&buf[length+7], &int16, 2);
		/* DO bit: the top bit of the 16-bit field just written */
		set_bit(&buf[length+7], 7, packet->opt.dnssec_ok);

		/* rdlength 0: no rdata */
		buf[length+9] = 0;
		buf[length+10] = 0;

		/* the header's ARCOUNT must include the OPT rr */
		int16 = htons((packet->count[SEC_ADD]) + 1);
		memcpy(&buf[10], &int16, 2);

		length += 11;
	}
	return length;
}

/**
 * Reads the given number of RR's from the buffer and places them in the given packet
 */
int
wire2rrs (uint8_t *buf, unsigned long offset, struct t_dpacket *packet, uint16_t qdcount, uint16_t ancount, uint16_t nscount, uint16_t adcount)
{
	unsigned long origoffset = offset;
	uint16_t i;
	uint32_t int32;
	uint8_t *minibuf;
	struct t_rr *rr;
	
	for (i=0; i<qdcount; i++) {
		assert(offset <= packet->udppacketsize);
		rr = wire2rr(buf, &offset, SEC_QUESTION);
		dpacket_add_rr(rr, SEC_QUESTION, packet);
	}

	for (i=0; i<ancount; i++) {
		assert(offset <= packet->udppacketsize);
		rr = wire2rr(buf, &offset, SEC_ANSWER);
		dpacket_add_rr(rr, SEC_ANSWER, packet);
	}

	for (i=0; i<nscount; i++) {
		assert(offset <= packet->udppacketsize);
		rr = wire2rr(buf, &offset, SEC_AUTH);
		dpacket_add_rr(rr, SEC_AUTH, packet);
	}

	for (i=0; i<adcount; i++) {
		assert(offset <= packet->udppacketsize);
		rr = wire2rr(buf, &offset, SEC_ADD);
		if (rr->type == TYPE_OPT) {
			if (rr->class > packet->udppacketsize) {
				SET_UDPSIZE(packet, rr->class);
			}
			
			int32 = rr->ttl;
			minibuf = xmalloc(4);
			memcpy(&minibuf[0], &int32, 4);
			packet->opt.dnssec_ok = bit_set(&minibuf[1], 7);
			xfree(minibuf);
			rr_destroy(rr, NO_FOLLOW);
			/* TODO: rest? */
			} else {
			dpacket_add_rr(rr, SEC_ADD, packet);
		}
	}
	
	/* OPT */
	assert(offset <= packet->udppacketsize);

	return (int) (offset-origoffset);
}

/**
 * Reads the domain name at buf[offset], following compression pointers,
 * and writes its presentation form, NUL-terminated, to name.
 * Every label is followed by '.', so the root name yields "" and any other
 * name ends in '.'. Octets 0 and >= 128 are written as \DDD (decimal);
 * '.', '(' and ')' are backslash-escaped; other octets are copied as-is.
 *
 * name should be allocated to hold getnamesize() characters.
 * Returns the number of bytes the name occupies at buf[offset]. If
 * compression is used, that is up to and including the first pointer.
 * Exits with a message on stderr if the compression pointers loop.
 * NOTE(review): buf's length is not known here, so reads are not
 * bounds-checked.
 */
size_t
wire2namestr (uint8_t *buf, size_t offset, char *name)
{
	/* generous bound on pointer hops; only a pointer loop should exceed it */
	enum { MAX_POINTER_HOPS = 128 };
	size_t origoffset = offset;
	size_t labeloffset = offset;
	uint8_t labelsize = buf[labeloffset];
	int compression_used = 0;
	int hops = 0;
	int destpos = 0;
	uint8_t srcpos;
	uint8_t c;

	while (labelsize > 0) {
		/* compression pointer: top two bits set, 14-bit target offset */
		if (labelsize >= 192) {
			/* the rr continues after the first pointer we meet */
			if (!compression_used) {
				offset = labeloffset + 2;
				compression_used = 1;
			}
			if (++hops > MAX_POINTER_HOPS) {
				fprintf(stderr, "Compression pointer loop in wire2namestr\n");
				exit(1);
			}
			labeloffset = ((size_t) (buf[labeloffset] & 0x3f) << 8) | buf[labeloffset + 1];
			labelsize = buf[labeloffset];
			/* re-examine the target: it may be the root, another
			 * pointer or an ordinary label */
			continue;
		}

		for (srcpos = 0; srcpos < labelsize; srcpos++) {
			c = buf[labeloffset + srcpos + 1];
			if (c == 0 || c >= 128) {
				name[destpos] = '\\';
				snprintf(&name[destpos + 1], 4, "%03d", (int) c);
				destpos += 4;
			} else if (c == '.' || c == '(' || c == ')') {
				name[destpos] = '\\';
				name[destpos + 1] = (char) c;
				destpos += 2;
			} else {
				name[destpos] = (char) c;
				destpos++;
			}
		}

		name[destpos] = '.';
		destpos++;
		labeloffset += labelsize + 1;
		labelsize = buf[labeloffset];
	}

	/* uncompressed: the name ends just after its terminating 0 */
	if (!compression_used) {
		offset = labeloffset + 1;
	}

	name[destpos] = '\0';
	return (size_t) (offset - origoffset);
}


/**
 * Parses one rr starting at buf[*offsetp], returns it as a newly created
 * struct t_rr tagged with section, and advances *offsetp past it.
 *
 * The owner name is decoded with wire2namestr() into a temporary string
 * sized by getnamesize(). That string is wrapped in an rdata for
 * rr_create(), and the temporary rdata is destroyed afterwards (presumably
 * rr_create() keeps its own copy — TODO confirm). For SEC_QUESTION only
 * type and class follow the name; for other sections the ttl and the rdata
 * (via wire2rdata()) are read too.
 * NOTE(review): reads are not checked against the packet length here.
 */
struct t_rr *
wire2rr (uint8_t *buf, unsigned long *offsetp, t_section section) 
{
	unsigned long pos = *offsetp;
	const size_t owner_size = getnamesize(buf, pos);
	uint8_t *owner_text = xmalloc(owner_size + 1);
	struct t_rdata *owner;
	struct t_rr *result;
	uint16_t net16;
	uint32_t net32;

	/* owner name -> temporary rdata -> new rr */
	pos += wire2namestr(buf, pos, (char *) owner_text);
	owner = rdata_create(owner_text, owner_size - 1);
	xfree(owner_text);

	result = rr_create(owner, 0, 0, section);
	rdata_destroy(owner);
	result->section = section;

	/* fixed-size fields are big-endian on the wire */
	memcpy(&net16, &buf[pos], sizeof net16);
	result->type = ntohs(net16);
	pos += sizeof net16;

	memcpy(&net16, &buf[pos], sizeof net16);
	result->class = ntohs(net16);
	pos += sizeof net16;

	if (section == SEC_QUESTION) {
		/* question entries end after type and class */
		*offsetp = pos;
		return result;
	}

	memcpy(&net32, &buf[pos], sizeof net32);
	result->ttl = ntohl(net32);
	pos += sizeof net32;

	pos += wire2rdata(buf, pos, result);

	*offsetp = pos;
	return result;
}

/**
 * Parses a complete DNS message of len bytes from buf into packet; the
 * packet must have been created first.
 *
 * Decodes the 12-byte header (id, flags, section counts), sets
 * packet->udppacketsize to len, and then reads all rrs with wire2rrs(),
 * which asserts that its read offset stays within that size.
 * Always returns RET_SUCCESS.
 */
int
wire2packet(uint8_t *buf, struct t_dpacket *packet, size_t len)
{
	uint16_t int16;
	uint16_t qdcount, ancount, nscount, adcount;

	/* The header alone is 12 bytes; a shorter buffer would be read past its
	 * end. This is checked with assert, as wire2rrs does for rr offsets. */
	assert(len >= 12);

	memcpy(&int16, buf, 2);
	packet->id = ntohs(int16);
	packet->udppacketsize = len;

	/* QR */
	packet->flags.qr = bit_set(&buf[2], 7);
	/* opcode */
	packet->flags.opcode = (unsigned int) get_bits(&buf[2], 6, 3);
	/* aa */
	packet->flags.aa = bit_set(&buf[2], 2);
	/* tc */
	packet->flags.tc = bit_set(&buf[2], 1);
	/* rd */
	packet->flags.rd = bit_set(&buf[2], 0);
	/* ra */
	packet->flags.ra = bit_set(&buf[3], 7);
	/* z */
	packet->flags.z = bit_set(&buf[3], 6);
	/* ad */
	packet->flags.ad = bit_set(&buf[3], 5);
	/* cd */
	packet->flags.cd = bit_set(&buf[3], 4);
	/* rcode */
	packet->flags.rcode = (unsigned int) get_bits(&buf[3], 0, 3);
	/* qdcount */
	memcpy(&int16, &buf[4], 2);
	qdcount = ntohs(int16);
	/* ancount */
	memcpy(&int16, &buf[6], 2);
	ancount = ntohs(int16);
	/* nscount */
	memcpy(&int16, &buf[8], 2);
	nscount = ntohs(int16);
	/* arcount */
	memcpy(&int16, &buf[10], 2);
	adcount = ntohs(int16);

	/* rrs start right after the fixed header */
	wire2rrs(buf, 12, packet, qdcount, ancount, nscount, adcount);
	return RET_SUCCESS;
}
