/**
 * net.c 
 * Contains all the send/receive stuff
 *
 * (c) NLnet Labs 2004
 *
 * See the file LICENSE for the license
 */


#include "common.h"

#include <arpa/inet.h>
/**
 * Default receive timeout used by the socket code in this file as the
 * setsockopt(SO_RCVTIMEO) value: DEFAULT_TIMEOUT_SEC seconds plus
 * DEFAULT_TIMEOUT_USEC microseconds.
 */
#define DEFAULT_TIMEOUT_SEC 8
#define DEFAULT_TIMEOUT_USEC 0

/*
 * Perform one step of a zone transfer (AXFR) over TCP.
 *
 * start_or_close is intended to select the step:
 *   0 - send the AXFR query and return the first reply packet
 *   1 - return the next reply packet from the same connection
 *   2 - close the connection
 * NOTE(review): start_or_close is currently ignored; every call is simply
 * forwarded to send_packet() with PROTO_TCP, and closing is not implemented.
 *
 * psocket is passed through unchanged so the TCP code can keep the
 * connection between calls.
 */
struct t_dpacket *
send_axfr_packet(struct t_dpacket *axfr, struct t_rr *ns, int start_or_close,
		int *psocket)
{
	struct t_dpacket *reply;

	/* TODO: honour start_or_close == 2 by closing *psocket */
	(void) start_or_close;

	reply = send_packet(axfr, ns, PROTO_TCP, psocket);
	return reply;
}

/**
 * Sends packet p and returns the parsed response packet.
 *
 * The nameservers in the list ns are tried in turn (through sendq_ns())
 * while the result is RET_FAIL. protocol is PROTO_UDP, PROTO_TCP or
 * PROTO_ANY; with PROTO_ANY a response carrying the TC (truncated) bit is
 * fetched again over TCP. psocket is handed through to the TCP code so a
 * connection can be kept between calls; it may be NULL.
 *
 * The response's serverip is a string copy of ns->rdata[0] (the first
 * server in the list) and its querytime the elapsed wall-clock time in
 * milliseconds.
 *
 * Returns a packet created by dpacket_create(), or NULL if something failed.
 */
struct t_dpacket *
send_packet(struct t_dpacket *p, struct t_rr *ns, int protocol, int *psocket)
{
	struct t_dpacket *fromwire;
	uint8_t *recbuf;
	uint8_t *wire;
	size_t recsize = 0;
	struct t_rr *cur_ns = ns;
	int send_result = RET_FAIL;
	char *srvip;
	struct timeval tv_s;
	struct timeval tv_e;

	if (drill_opt->verbose > 1) {
		srvip = rdata2str(ns->rdata[0]);
		mesg("Sending packet to %s:", srvip);
		print_packet(p);
		xfree(srvip);
	}
	recbuf = xmalloc(MAX_PACKET);
	memset(recbuf, 0, MAX_PACKET);

	gettimeofday(&tv_s, NULL);
	while (send_result == RET_FAIL && cur_ns != NULL) {
		/* recbuf is passed by address: the TCP code may xrealloc() it
		 * for answers that do not fit in MAX_PACKET */
		send_result = sendq_ns(p, cur_ns, &recbuf, &recsize, protocol, psocket);
		cur_ns = cur_ns->next;
	}
	gettimeofday(&tv_e, NULL);

	if (send_result != RET_SUC) {
		xfree(recbuf);
		return NULL;
	}

	/* TCP answers are stored behind their 2-byte length prefix. That is a
	 * property of the transport alone, so skip it for every TCP answer. */
	wire = recbuf;
	if (protocol == PROTO_TCP)
		wire += sizeof(uint16_t);

	fromwire = dpacket_create();
	if (fromwire == NULL) {
		xfree(recbuf);
		return NULL;
	}

	wire2packet(wire, fromwire, recsize);
	fromwire->udppacketsize = recsize;
	fromwire->serverip = rdata2str(ns->rdata[0]);

	if (drill_opt->verbose > 0) {
		verbose("Received packet:");
		print_packet(fromwire);
	}

	/* #id check */
	if (p->id != fromwire->id)
		warning("packet ID mismatch %d != %d", p->id, fromwire->id);

	fromwire->querytime = (tv_e.tv_sec - tv_s.tv_sec) * 1000;
	fromwire->querytime += (tv_e.tv_usec - tv_s.tv_usec) / 1000;

	xfree(recbuf); /* free original data */

	if (GET_TC(fromwire) == ON) {
		/* do this here so that we have more control on whether this
		 * is really needed. Maybe check if the (name,type) we ask
		 * really is in the packet */
		if (protocol == PROTO_ANY) {
			mesg("Truncated packet; redoing in TCP mode");
			dpacket_destroy(fromwire);
			fromwire = send_packet(p, ns, PROTO_TCP, NULL);
		} else if (protocol == PROTO_TCP) {
			mesg("What the .. : a truncated TCP packet??");
		} else if (have_drill_opt && drill_opt->verbose > 0) {
			warning("Truncated packet, but runnig udp only, return it");
		}
	}
	return fromwire;
}

/**
 * Encode pkt into wire format and send it over UDP/IPv4 to srvip.
 * send_udp_raw() stores the answer in *reply and its length in *size;
 * its status code is returned unchanged.
 */
int
sendq_udp(struct t_dpacket *pkt, struct t_rdata *srvip, uint8_t **reply, size_t *size)
{
	uint8_t *wire = xmalloc(MAX_PACKET);
	size_t wirelen;
	int status;

	memset(wire, 0, MAX_PACKET);
	wirelen = packet2wire(pkt, wire, MAX_PACKET);
	status = send_udp_raw(wire, wirelen, srvip, reply, size);

	xfree(wire);
	return status;
}

/**
 * Encode pkt into wire format and send it over UDP/IPv6 to srvip.
 * send_udp_raw6() stores the answer in *reply and its length in *size;
 * its status code is returned unchanged.
 */
int
sendq_udp6(struct t_dpacket *pkt, struct t_rdata *srvip, uint8_t **reply, size_t *size)
{
	uint8_t *wire = xmalloc(MAX_PACKET);
	size_t wirelen;
	int status;

	memset(wire, 0, MAX_PACKET);
	wirelen = packet2wire(pkt, wire, MAX_PACKET);
	status = send_udp_raw6(wire, wirelen, srvip, reply, size);

	xfree(wire);
	return status;
}


/**
 * Does the actual udp sending (IPv6).
 *
 * Sends the bufsize bytes in sendbuf as one datagram to the IPv6 address
 * in srvip, port drill_opt->port (DNS_PORT when have_drill_opt is not
 * set), then waits for one reply datagram for at most
 * DEFAULT_TIMEOUT_SEC seconds / DEFAULT_TIMEOUT_USEC microseconds.
 *
 * *reply must point to a buffer of at least MAX_PACKET bytes; the reply
 * is stored there and its length in *reply_size.
 *
 * Returns RET_SUC if a non-empty reply was received, RET_FAIL or
 * RET_FAILURE otherwise.
 */
int
send_udp_raw6(uint8_t *sendbuf, size_t bufsize, struct t_rdata *srvip, 
		uint8_t **reply, size_t *reply_size)
{
	struct sockaddr_in6 src, dest;
	socklen_t frmlen;
	int sockfd;
	int pton;
	ssize_t amount;
	struct timeval timeout;
	char *address;

	timeout.tv_sec = DEFAULT_TIMEOUT_SEC;
	timeout.tv_usec = DEFAULT_TIMEOUT_USEC;

	/* clear everything first so sin6_flowinfo, sin6_scope_id and any
	 * platform-specific members do not carry stack garbage */
	memset(&src, 0, sizeof(src));
	memset(&dest, 0, sizeof(dest));

	/* prepare locally: any address, ephemeral port */
	src.sin6_family = AF_INET6;
	src.sin6_addr = in6addr_any;
	src.sin6_port = htons(0);

	dest.sin6_family = AF_INET6;
	if (have_drill_opt) {
		dest.sin6_port = htons(drill_opt->port);
	} else {
		dest.sin6_port = htons(DNS_PORT);
	}

	address = rdata2str(srvip);
	pton = inet_pton(AF_INET6, address, &dest.sin6_addr);
	if (pton != 1) {
		/* 0: not a valid IPv6 address, -1: error with errno set.
		 * Either way there is no destination to send to. */
		error("Bad address: %s", address);
		if (pton < 0)
			perror("inet_pton");
		xfree(address);
		return RET_FAIL;
	}
	xfree(address);

	frmlen = (socklen_t) sizeof(src); /* for use in recvfrom */

	if ((sockfd = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)) == -1) {
		error("(udp/ipv6) could not create socket");
		perror("socket");
		return RET_FAILURE;
	}

	if (bind(sockfd, (struct sockaddr *) &src, frmlen) == -1) {
		/* report errno before close() can overwrite it */
		perror("IPv6 bind failed");
		close(sockfd);
		return RET_FAILURE;
	}

	if (drill_opt->verbose > 1) {
		address = rdata2str(srvip);
		mesg("Sending bytes (to ip %s):", address);
		print_bytes_hex(sendbuf, 0, (int) bufsize, 20);
		xfree(address);
	}

	amount = sendto(sockfd, sendbuf, bufsize, 0,
			(struct sockaddr *) &dest, (socklen_t) sizeof(dest));
	if (amount < 0 || (size_t) amount != bufsize) {
		if (amount < 0)
			perror("sendto");
		warning("IPv6 send returned wrong size: %d instead of %d",
			(int) amount, (int) bufsize);
		close(sockfd);
		return RET_FAIL;
	}

	if (setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
			(socklen_t) sizeof(timeout)) != 0) {
		perror("setsockopt");
		close(sockfd);
		return RET_FAIL;
	}

	*reply_size = (size_t) recvfrom(sockfd, *reply, MAX_PACKET, 0, /* flags */
			(struct sockaddr *) &src, &frmlen);
	close(sockfd);

	if (*reply_size == (size_t) -1)
		return RET_FAIL;

	if (drill_opt->verbose > 1) {
		mesg("Received bytes:");
		print_bytes_hex(*reply, 0, (int) *reply_size, 20);
	}

	return (*reply_size == 0) ? RET_FAIL : RET_SUC;
}


/**
 * Does the actual udp sending (IPv4).
 *
 * Sends the bufsize bytes in sendbuf as one datagram to the IPv4 address
 * in srvip, port drill_opt->port (DNS_PORT when have_drill_opt is not
 * set), then waits for one reply datagram for at most
 * DEFAULT_TIMEOUT_SEC seconds / DEFAULT_TIMEOUT_USEC microseconds.
 *
 * *reply must point to a buffer of at least MAX_PACKET bytes; the reply
 * is stored there and its length in *reply_size.
 *
 * Returns RET_SUC if a non-empty reply was received, RET_FAIL or
 * RET_FAILURE otherwise.
 */
int
send_udp_raw(uint8_t *sendbuf, size_t bufsize, struct t_rdata *srvip, 
		uint8_t **reply, size_t *reply_size)
{
	struct sockaddr_in src, dest;
	socklen_t frmlen;
	int sockfd;
	ssize_t amount;
	struct timeval timeout;
	char *serverstr;

	timeout.tv_sec = DEFAULT_TIMEOUT_SEC;
	timeout.tv_usec = DEFAULT_TIMEOUT_USEC;

	/* clear everything first so sin_zero and any platform-specific
	 * members do not carry stack garbage into bind()/sendto() */
	memset(&src, 0, sizeof(src));
	memset(&dest, 0, sizeof(dest));

	/* prepare locally: any address, ephemeral port */
	src.sin_family = AF_INET;
	src.sin_addr.s_addr = (in_addr_t) htonl(INADDR_ANY);
	src.sin_port = htons(0);

	dest.sin_family = AF_INET;
	if (have_drill_opt) {
		dest.sin_port = htons(drill_opt->port);
	} else {
		dest.sin_port = htons(DNS_PORT);
	}

	/* inet_pton() (as the TCP code uses) rather than inet_addr(): a
	 * malformed address is reported instead of silently becoming
	 * INADDR_NONE, i.e. 255.255.255.255 */
	serverstr = rdata2str(srvip);
	if (inet_pton(AF_INET, serverstr, &dest.sin_addr) != 1) {
		warning("(udp/ipv4) bad address: %s", serverstr);
		xfree(serverstr);
		return RET_FAIL;
	}
	xfree(serverstr);

	frmlen = (socklen_t) sizeof(src); /* for use in recvfrom */

	if ((sockfd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) == -1) {
		perror("could not create socket");
		return RET_FAILURE;
	}

	if (bind(sockfd, (struct sockaddr *) &src, (socklen_t) sizeof(src)) == -1) {
		/* report errno before close() can overwrite it */
		perror("IPv4 bind failed");
		close(sockfd);
		return RET_FAILURE;
	}

	if (drill_opt->verbose > 1) {
		serverstr = rdata2str(srvip);
		mesg("Sending bytes (to ip %s):", serverstr);
		print_bytes_hex(sendbuf, 0, (int) bufsize, 20);
		xfree(serverstr);
	}

	amount = sendto(sockfd, sendbuf, bufsize, 0,
			(struct sockaddr *) &dest, (socklen_t) sizeof(dest));
	if (amount < 0 || (size_t) amount != bufsize) {
		if (amount < 0)
			perror("sendto");
		close(sockfd);
		return RET_FAIL;
	}

	if (setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
			(socklen_t) sizeof(timeout)) != 0) {
		perror("setsockopt");
		close(sockfd);
		return RET_FAIL;
	}

	*reply_size = (size_t) recvfrom(sockfd, *reply, MAX_PACKET, 0, /* flags */
			(struct sockaddr *) &src, &frmlen);
	close(sockfd);

	if (*reply_size == (size_t) -1)
		return RET_FAIL;

	if (drill_opt->verbose > 1) {
		mesg("Received bytes:");
		print_bytes_hex(*reply, 0, (int) *reply_size, 20);
	}

	return (*reply_size == 0) ? RET_FAIL : RET_SUC;
}



/**
 * Sends the packet to a server over tcp (IPv6)
 */
int
sendq_tcp6(struct t_dpacket *pkt, struct t_rdata *srvip, uint8_t **reply, size_t *size, int *psocket)
{
	struct sockaddr_in6 dest;
	int sockfd;
	uint8_t *buf;
	size_t bufsize;
	size_t amount;
	uint16_t tcpbufsize; /* for tcp we must send the length first */
	char *serverstr;
	
	/* prepare locally */
	dest.sin6_family = AF_INET6;
	serverstr = rdata2str(srvip);
	inet_pton(AF_INET6, serverstr,&(dest.sin6_addr));
	xfree(serverstr);
	
	if (have_drill_opt) {
		dest.sin6_port = htons(drill_opt->port);
	} else {
		dest.sin6_port = htons(DNS_PORT);
	}

	if (psocket == NULL || *psocket == 0) {
		if ((sockfd = socket(AF_INET6, SOCK_STREAM, 0)) == -1) {
			perror("(tcp/ipv6) could not create socket");
			return RET_FAILURE;
		}

		if (connect(sockfd, (void*)&dest, (socklen_t) sizeof(dest)) == -1) {
			perror("(tcp/ipv6) could not connect");
			return RET_FAILURE;
		}
	
		if (psocket) {
			*psocket = sockfd;
		} else {
			psocket = &sockfd;
		}
	}  else {
		sockfd = *psocket;
	}


	buf = xmalloc(MAX_PACKET); 
	memset(buf, 0, MAX_PACKET);
	bufsize = packet2wire(pkt, buf, MAX_PACKET);
	tcpbufsize = htons(bufsize);
	
	if (drill_opt->verbose > 1) {
		mesg("Sending bytes:");
		print_bytes_hex(buf, 0, (int) bufsize, 20);
	}
	
	/* send the tcp length */
	amount = (size_t) send(sockfd, &tcpbufsize, (socklen_t) sizeof(uint16_t), 0);
	if (sizeof(uint16_t) != amount)
		return RET_FAIL;
	
	amount = (size_t) send(sockfd, buf, bufsize, 0);
	if (bufsize != amount)
		return RET_FAIL;

	/* as this is a tcp reply, the first 2 bytes are the length of the
	 * packet */
	*size = (size_t) recv(sockfd, *reply, 2, 0);
	if (*size != 2) {
		warning("(ipv6/tcp) recv size not 2 (%d)", (int) *size);
		return RET_FAIL;
	}

	memcpy(&tcpbufsize, *reply, sizeof(uint16_t));
	tcpbufsize = ntohs(tcpbufsize);

	if (tcpbufsize > MAX_PACKET) {
		*reply = xrealloc(*reply, tcpbufsize+2);
		memset(*reply, 0, tcpbufsize+2);
		tcpbufsize = htons(tcpbufsize);
		memcpy(*reply, &tcpbufsize, sizeof(uint16_t));
		tcpbufsize = ntohs(tcpbufsize);
	}

	*size = (size_t) recv(sockfd, *reply+2, tcpbufsize, 0);
	while (*size < tcpbufsize) {
		*size += recv(sockfd, *reply+*size+2, tcpbufsize, 0);
	}

	/* TODO match size with tcpbufsize */
	xfree(buf);
	
	if (*size == (size_t) 0 || *size == (size_t) -1 || *size != tcpbufsize) {
		perror("(ipv6/tcp) recv");
		return RET_FAIL;
	} else {

		if (drill_opt->verbose > 1) {
			mesg("Received bytes:");
			print_bytes_hex(*reply, 0, (int) *size, 20);
		}
		return RET_SUC;
	}
}


/*
 * Helper for sendq_tcp(): read exactly len bytes from the stream socket
 * sockfd into buf, retrying short reads. Returns the number of bytes read;
 * a value below len means recv() reported end of stream (0) or an error
 * (-1, which includes a receive timeout).
 */
static size_t
tcp4_recv_exact(int sockfd, uint8_t *buf, size_t len)
{
	size_t got = 0;
	ssize_t n;

	while (got < len) {
		n = recv(sockfd, buf + got, len - got, 0);
		if (n <= 0)
			break;
		got += (size_t) n;
	}
	return got;
}

/**
 * Sends the packet to a server over tcp (IPv4) and reads one answer.
 *
 * psocket controls connection handling:
 *  - psocket == NULL: a connection is opened for this exchange and closed
 *    again before returning;
 *  - *psocket == 0: a connection is opened and the query sent; on success
 *    the socket is stored in *psocket (left unchanged on failure);
 *  - *psocket != 0: the query is NOT sent again, only the next answer is
 *    read from that connection (follow-up packets, e.g. of a zone
 *    transfer).
 *
 * On success *reply holds the 2-byte length prefix followed by the DNS
 * message and *size is the message length without the prefix. *reply must
 * hold at least MAX_PACKET bytes and is grown with xrealloc() if needed.
 *
 * Returns RET_SUC on success, RET_FAIL or RET_FAILURE on error.
 */
int
sendq_tcp(struct t_dpacket *pkt, struct t_rdata *srvip, uint8_t **reply, size_t *size, int *psocket)
{
	struct sockaddr_in dest;
	struct timeval timeout;
	int sockfd;
	int fresh = 0;		/* socket was opened by this call */
	int result = RET_FAIL;
	uint8_t *buf;
	size_t bufsize;
	size_t msglen;
	uint16_t tcpbufsize;	/* for tcp we must send the length first */
	char *serverstr;

	if (psocket != NULL && *psocket != 0) {
		/* existing connection: the query went out on an earlier call */
		sockfd = *psocket;
	} else {
		memset(&dest, 0, sizeof(dest));
		dest.sin_family = AF_INET;
		if (have_drill_opt) {
			dest.sin_port = htons(drill_opt->port);
		} else {
			dest.sin_port = htons(DNS_PORT);
		}

		serverstr = rdata2str(srvip);
		if (inet_pton(AF_INET, serverstr, &dest.sin_addr) != 1) {
			warning("(tcp/ipv4) bad address: %s", serverstr);
			xfree(serverstr);
			return RET_FAILURE;
		}
		xfree(serverstr);

		if ((sockfd = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
			perror("(tcp/ipv4) could not create socket");
			return RET_FAILURE;
		}
		fresh = 1;

		timeout.tv_sec = DEFAULT_TIMEOUT_SEC;
		timeout.tv_usec = DEFAULT_TIMEOUT_USEC;
		if (setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
				(socklen_t) sizeof(timeout)) != 0) {
			perror("setsockopt");
			goto done;
		}

		if (connect(sockfd, (struct sockaddr *) &dest, (socklen_t) sizeof(dest)) == -1) {
			perror("(tcp/ipv4) could not connect");
			result = RET_FAILURE;
			goto done;
		}

		buf = xmalloc(MAX_PACKET);
		memset(buf, 0, MAX_PACKET);
		bufsize = packet2wire(pkt, buf, MAX_PACKET);
		tcpbufsize = htons((uint16_t) bufsize);

		if (drill_opt->verbose > 1) {
			mesg("Sending bytes:");
			print_bytes_hex(buf, 0, (int) bufsize, 20);
		}

		/* send the tcp length, then the message itself */
		if (send(sockfd, &tcpbufsize, sizeof(tcpbufsize), 0) != (ssize_t) sizeof(tcpbufsize)
				|| send(sockfd, buf, bufsize, 0) != (ssize_t) bufsize) {
			xfree(buf);
			goto done;
		}
		xfree(buf);
	}

	/* as this is a tcp reply, the first 2 bytes are the length of the
	 * packet. Some servers (a certain bind 8.2 did it) deliver those two
	 * bytes in separate segments, so keep reading until both are in. */
	if (tcp4_recv_exact(sockfd, *reply, sizeof(uint16_t)) != sizeof(uint16_t)) {
		warning("(ipv4/tcp) could not read the 2-byte length prefix");
		goto done;
	}
	memcpy(&tcpbufsize, *reply, sizeof(uint16_t));
	msglen = ntohs(tcpbufsize);
	if (msglen == 0) {
		warning("(ipv4/tcp) server announced an empty answer");
		goto done;
	}

	/* the buffer must hold the prefix AND the message */
	if (msglen + sizeof(uint16_t) > MAX_PACKET) {
		*reply = xrealloc(*reply, msglen + sizeof(uint16_t));
		memcpy(*reply, &tcpbufsize, sizeof(uint16_t));
	}

	*size = tcp4_recv_exact(sockfd, *reply + sizeof(uint16_t), msglen);
	if (*size != msglen) {
		warning("(ipv4/tcp) recv: size (%d) should be %d", (int) *size, (int) msglen);
		goto done;
	}

	if (drill_opt->verbose > 1) {
		mesg("Received bytes:");
		print_bytes_hex(*reply, 0, (int) (*size + sizeof(uint16_t)), 20);
	}
	result = RET_SUC;

done:
	if (fresh) {
		if (result == RET_SUC && psocket != NULL)
			*psocket = sockfd;	/* the caller keeps the connection */
		else
			close(sockfd);
	}
	return result;
}

/**
 * Send p to the nameservers in the list ns, in turn, using the given
 * protocol (PROTO_ANY and PROTO_UDP both go over UDP).
 *
 * The address family comes from each server's rdata[0] (the A or AAAA
 * address). When drill_opt->transport is 4 or 6, a server of the other
 * family is skipped with a warning. After an unsuccessful attempt the next
 * server is tried, unless drill_opt->fail is 1. For any other transport
 * value the first server's result is returned directly.
 *
 * reply, size and psocket are handed to the UDP/TCP senders.
 * Returns RET_SUCCESS when a server answered, RET_FAILURE otherwise (or
 * the sender's own status in the unrestricted-transport case).
 *
 * RRT is not done - is this needed?
 */
int 
sendq_ns(struct t_dpacket *p, struct t_rr *ns, uint8_t **reply, size_t *size,
		int protocol, int *psocket)
{
	struct t_rr *tmp;
	int ip6;
	int result = RET_FAILURE;
	char *nsstr;

	/* currently this function is used WRONG - the nameserver cycling
	 * should take place here, but this has not been used */
	for (tmp = ns; tmp != NULL; tmp = tmp->next) {
		/* rdata[0] is the address of the A or AAAA record */
		ip6 = is_ipv6_addr_rdata(tmp->rdata[0]);

		nsstr = rdata2str(tmp->rdata[0]);
		verbose("Asking nameserver: %s\n", nsstr);
		xfree(nsstr);

		/* every case below ends in break or return: falling through
		 * would print a bogus family warning or send the query twice */
		switch (protocol) {
			case PROTO_ANY:
			case PROTO_UDP:
				switch(drill_opt->transport) {
					case 4:
						if (ip6) {
							warning("%s", "-4 enabled, IPv6 address encountered");
							break;
						}
						result = sendq_udp(p, tmp->rdata[0], reply, size);
						break;
					case 6:
						if (!ip6) {
							warning("%s", "-6 enabled, IPv4 address encountered");
							break;
						}
						result = sendq_udp6(p, tmp->rdata[0], reply, size);
						break;
					default:
						if (ip6) {
							return sendq_udp6(p, tmp->rdata[0], reply, size);
						} else {
							return sendq_udp(p, tmp->rdata[0], reply, size);
						}
				}
				break;
			case PROTO_TCP:
				switch(drill_opt->transport) {
					case 4:
						if (ip6) {
							warning("%s", "-4 enabled, IPv6 address encountered");
							break;
						}
						result = sendq_tcp(p, tmp->rdata[0], reply, size, psocket);
						break;
					case 6:
						if (!ip6) {
							warning("%s", "-6 enabled, IPv4 address encountered");
							break;
						}
						result = sendq_tcp6(p, tmp->rdata[0], reply, size, psocket);
						break;
					default:
						if (ip6) {
							return sendq_tcp6(p, tmp->rdata[0], reply, size, psocket);
						} else {
							return sendq_tcp(p, tmp->rdata[0], reply, size, psocket);
						}
				}
				break;
			default:
				error("Unknown protocol: %d", protocol);
				break;
		}
		if (RET_SUCCESS == result)
			return RET_SUCCESS;
		/* go on with the next nameserver, unless -a is specified */
		if (1 == drill_opt->fail)
			break;
	}
	return RET_FAILURE;
}

/**
 * Reconstruct a packet from the wire-format data stored in a file.
 *
 * The file is read with packetbuffromfile() into a MAX_PACKET-byte buffer
 * and parsed with wire2packet(). The packet's serverip is set to a copy of
 * filename. Its querytime is the time, in milliseconds, spent reading and
 * parsing the file, since no network query takes place.
 *
 * Returns the new packet (owned by the caller), or NULL if
 * dpacket_create() failed.
 *
 * This is more or less a debug function.
 */
struct t_dpacket *
file_packet(char *filename)
{
	struct t_dpacket *fromwire;
	uint8_t *recbuf;
	size_t recsize;
	struct timeval tv_s;
	struct timeval tv_e;

	recbuf = xmalloc(MAX_PACKET);
	memset(recbuf, 0, MAX_PACKET);

	/* tv_s/tv_e must be set before querytime is computed from them */
	gettimeofday(&tv_s, NULL);
	recsize = packetbuffromfile(filename, recbuf);

	fromwire = dpacket_create();
	if (fromwire == NULL) {
		xfree(recbuf);
		return NULL;
	}
	wire2packet(recbuf, fromwire, recsize);
	gettimeofday(&tv_e, NULL);

	fromwire->udppacketsize = recsize;
	fromwire->serverip = xstrdup(filename);

	if (drill_opt->verbose > 0) {
		mesg("Received packet:");
		print_packet(fromwire);
	}

	fromwire->querytime = (tv_e.tv_sec - tv_s.tv_sec) * 1000;
	fromwire->querytime += (tv_e.tv_usec - tv_s.tv_usec) / 1000;

	xfree(recbuf); /* free original data */

	return fromwire;
}
