/**
 * resolver.c
 *
 * Resolver code: query construction, answer handling and name resolution.
 *
 * (c) NLnet Labs, 2004
 * 
 * See the file LICENSE for the license
 *
 */

#include "common.h"
#include <assert.h>
#include <ctype.h>
#include <netdb.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>

/**
 * Performs a query for the given query RR
 *
 * Builds a packet with RD set, plus DNSSEC (and, for UDP, the configured
 * buffer size) when enabled in drill_opt, and CD when tracing. A clone of
 * query is placed in the question section. The packet is sent to
 * nameservers with send_packet() and destroyed afterwards.
 *
 * @param query        the question RR (not modified; a clone is sent)
 * @param nameservers  servers to send the query to
 * @param protocol     transport (PROTO_UDP enables the UDP size option)
 * @return the packet returned by send_packet() (NULL presumably means no
 *         response - confirm against send_packet)
 */
struct t_dpacket *
do_query_rr(struct t_rr *query, struct t_rr *nameservers, int protocol)
{
	struct t_dpacket *p;
	struct t_dpacket *q;

	p = dpacket_create();

	/* set values...
	 * TODO: get them from options or arguments (?)
	 */
	SET_RD(p);

	/* drill_opt may be NULL; the CD check below already guards for that */
	if (drill_opt != NULL && drill_opt->dnssec) {
		vverbose("Setting DNSSEC to ON");
		SET_DNSSEC(p);
		if (protocol == PROTO_UDP) {
			SET_UDPSIZE(p, drill_opt->bufsize);
		}
	}

	/* other reasons to use CD=1? or maybe always? */
	if (drill_opt != NULL && drill_opt->purpose == TRACE) {
		SET_CD(p);
	}

	/* the packet takes ownership of the clone */
	dpacket_add_rr(rr_clone(query, NO_FOLLOW), SEC_QUESTION, p);

	q = send_packet(p, nameservers, protocol, NULL);

	dpacket_destroy(p);

	return q;
}

/**
 * Sends a single question for (name, type) to the given name servers.
 *
 * A temporary question RR is built with rr_create() (DEF_TTL, question
 * section), handed to do_query_rr(), and destroyed before returning.
 *
 * @return whatever do_query_rr() returned for that question
 */
struct t_dpacket *
do_query(struct t_rdata *name, int type, struct t_rr *nameservers, int protocol)
{
	struct t_dpacket *reply;
	struct t_rr *question = rr_create(name, (uint16_t) type, DEF_TTL,
			SEC_QUESTION);

	reply = do_query_rr(question, nameservers, protocol);
	rr_destroy(question, NO_FOLLOW);
	return reply;
}

/*
 * Returns 1 if s is exactly four dot-separated decimal octets (each 1-3
 * digits with a value of 0-255) with nothing before or after, else 0.
 * Leading zeros inside an octet are accepted ("01.2.3.4").
 */
static int
resolver_is_ipv4_literal(const char *s)
{
	int octet;

	for (octet = 0; octet < 4; octet++) {
		int value = 0;
		int digits = 0;

		if (octet > 0) {
			if (*s != '.') {
				return 0;
			}
			s++;
		}
		while (isdigit((unsigned char) *s)) {
			value = value * 10 + (*s - '0');
			digits++;
			/* at most 3 digits, so value can never overflow */
			if (digits > 3 || value > 255) {
				return 0;
			}
			s++;
		}
		if (digits == 0) {
			return 0;
		}
	}
	/* reject trailing garbage such as "1.2.3.4abc" or "1.2.3.4.5" */
	return *s == '\0';
}

/**
 * Checks whether the given char* is an ip address (such as 127.0.0.1)
 * or xxxx:xxxx:: etc
 *
 * Strings without a ':' must be a strict dotted-quad IPv4 address.
 * Strings with a ':' are accepted as IPv6 when they contain only hex
 * digits and colons. That is a cheap plausibility check, not a full
 * RFC 4291 parser (e.g. "::ffff:1.2.3.4" is rejected).
 *
 * @param address  NUL-terminated string, may be NULL
 * @return 1 if address looks like an IP address, 0 otherwise
 */
int
is_ip_address(const char *address)
{
	const char *p;

	if (address == NULL) {
		return 0;
	}
	if (strchr(address, ':') == NULL) {
		return resolver_is_ipv4_literal(address);
	}
	/* IPv6? */
	for (p = address; *p != '\0'; p++) {
		if (!(isxdigit((unsigned char) *p) || *p == ':')) {
			return 0;
		}
	}
	return 1;
}

/*
 * Crude IPv6 detector: reports 1 when addr contains a ':' anywhere and 0
 * otherwise. addr must not be NULL.
 * Temporary helper - slated for removal, do not build on it.
 */
int is_ipv6_addr(const char *addr) {
	return index(addr, ':') != NULL;
}

/*
 * Same crude test as is_ipv6_addr(), applied to the textual form of an
 * rdata: returns 1 if rdata2str(addr) contains a ':', 0 otherwise. The
 * temporary string is released with xfree() before returning.
 */
int is_ipv6_addr_rdata(struct t_rdata *addr) {
	char *text = rdata2str(addr);
	int has_colon = (index(text, ':') != NULL);

	xfree(text);
	return has_colon;
}

/**
 * Handles the answer packet (ie. performs security checks if necessary etc)
 */
int
handle_answer_packet(char *queryname, int type, struct t_dpacket *answer, struct t_rr *ns, struct t_rr
		*trusted_keys, struct zone_list *sz, int protocol, struct
		t_rr **result_rr) {
	struct t_rdata *rdata;
	struct t_rr *result = NULL;
	struct t_rr *alias;
	char *cname;
	char *resultname;
	int result_int = 0;
	
	rdata = rdata_create((uint8_t *) queryname, strlen(queryname));

	result = dpacket_get_rrset(rdata, (uint16_t) type, answer, SEC_ANSWER);

	if (result) {

		resultname = rdata2str(result->name);
		if (contains_zone_list(sz, resultname)) {

			verbose("ZONE %s MUST BE SECURE\n", rdata2str(result->name));
			
			if (do_chase_rr(result, answer, ns, trusted_keys,
				protocol) == RET_SUC) {
				result_int = RET_SUC;
			} else {
				warning("\nWARNING: Unable to verify the address given for %s\n\n", rdata2str(result->name));
				#ifdef EAI_SYSTEM
				result_int = EAI_SYSTEM;
				#else
				result_int = -11;
				#endif
			}
			
		}
		xfree(resultname);
		
	} else {
		/* maybe CNAME? */
		result = dpacket_get_rrset(rdata, TYPE_CNAME, answer, SEC_ANSWER);
		if (result) {
			cname = rdata2str(result->rdata[0]);
			
			alias = result;
			
			result_int = handle_answer_packet(cname, type, answer, ns,
					trusted_keys, sz, protocol, &result);
			
			(void) rr_add_rr(result, alias);
			
			xfree(cname);
		} else {
			
		}
		
		/* TODO maybe delegations? */
	}
	
	rdata_destroy(rdata);
	
	if (result_int == RET_SUC) {
		*result_rr = result;
	}
	if (result_int == RET_SUC) {
		return (int) answer->flags.rcode;
	} else {
		return result_int;
	}
}

/**
 * Resolves the given hostname to an address string, returns NULL if it
 * cannot be found.
 *
 * The name "localhost" is returned as a fresh copy of itself without any
 * lookup. Otherwise resolve2rr() is called (type 0, i.e. AAAA and A) and
 * the first rdata of the resulting rrset is converted with rdata2str().
 * When the answer's owner name is listed in sz, the record must verify
 * against trusted_keys (see handle_answer_packet()).
 *
 * @return a newly allocated, NUL-terminated string (free with xfree()),
 *         or NULL
 */
char *
resolve2string(char *hostname, struct t_rr *ns, struct t_rr *trusted_keys,
		struct zone_list *search, struct zone_list *sz, int protocol)
{
	static const char localhost[] = "localhost";
	struct t_rr *result = NULL;

	if (strcmp(hostname, localhost) == 0) {
		/* sizeof includes the terminating NUL */
		char *copy = xmalloc(sizeof localhost);

		memcpy(copy, localhost, sizeof localhost);
		return copy;
	}

	/* TODO check result? */
	(void) resolve2rr(hostname, ns, trusted_keys, search, sz, protocol, 0,
			&result);

	if (result == NULL) {
		return NULL;
	}
	/* NOTE(review): the rrset itself is not freed here - confirm who
	 * owns it */
	return rdata2str(result->rdata[0]);
}

/**
 * Resolve the given host name to an rr (or rrset)
 *
 * queryname is queried as-is (callers pass an absolute name). Unless type
 * is AF_INET, an AAAA query is sent first. If that gets no response at
 * all, the function fails immediately. Unless type is AF_INET6, an A query
 * follows. AAAA records come first in the returned set, with A records
 * appended.
 *
 * @param result_rr  receives the combined rrset (may be NULL); set to NULL
 *                   when the AAAA query gets no response
 * @return the status of the last handle_answer_packet() call, or RET_FAIL
 *         when no query was answered
 */
int
resolve2rr2(char *queryname, struct t_rr *ns, struct t_rr *trusted_keys,
		struct zone_list *sz, int protocol, int type, struct t_rr
		**result_rr)
{
	struct t_dpacket *answer;
	struct t_rdata *queryname_rdata;
	struct t_rr *result = NULL;
	struct t_rr *result2 = NULL;
	/* stays RET_FAIL when no query gets an answer (was uninitialized) */
	int result_int = RET_FAIL;

	verbose("resolving: %s\n", queryname);

	assert(ns != NULL);

	queryname_rdata = rdata_create((uint8_t *) queryname, strlen(queryname));

	if (type != AF_INET) {
		answer = do_query(queryname_rdata, TYPE_AAAA, ns, protocol);

		if (!answer) {
			error("No response from server (protocol %d): \n",protocol);
			print_rr(ns, NO_FOLLOW);
			perror("response error");
			rdata_destroy(queryname_rdata);
			*result_rr = NULL;
			return RET_FAIL;
		}
		/* TODO: error handling / check return value
		 * TODO: make get_rrset for char * in dns.c?
		 */
		result_int = handle_answer_packet(queryname, TYPE_AAAA, answer, ns,
				trusted_keys, sz, protocol, &result);
		dpacket_destroy(answer);
	}

	if (type != AF_INET6) {
		answer = do_query(queryname_rdata, TYPE_A, ns, protocol);

		if (answer) {
			/* TODO: check return value */
			result_int = handle_answer_packet(queryname, TYPE_A, answer, ns,
					trusted_keys, sz, protocol, &result2);
			dpacket_destroy(answer);
		}
	}

	rdata_destroy(queryname_rdata);

	/* ipv6 first */
	if (result == NULL) {
		*result_rr = result2;
	} else {
		if (result2 != NULL) {
			(void) rr_add_rr(result, result2);
		}
		*result_rr = result;
	}

	return result_int;
}

int
resolve2rr(char *hostname, struct t_rr *ns, struct t_rr *trusted_keys,
		struct zone_list *search,
		struct zone_list *sz, int protocol, int type, struct t_rr
		**result_rr)
{
	char *queryname;
	int i;
	int result;
	struct t_rdata  *rd;
	struct t_hostlist *hostlist;
	const char *ipaddr = NULL;
	
	/* check /etc/hosts */
	/* TODO: do this only once and not at every call to resolve() */
	hostlist = read_hosts_file("/etc/hosts");
	ipaddr = hostlist_get_ipaddr(hostlist, hostname);
	if (ipaddr != NULL) {
		if (is_ipv6_addr(ipaddr)) {
			rd = rdata_create((uint8_t *) hostname, strlen(hostname));
			*result_rr = rr_create(rd, TYPE_AAAA, 0, SEC_ANSWER);
			rr_add_rdata(rdata_create((uint8_t *) ipaddr, strlen(ipaddr)), *result_rr);
			hostlist_destroy(hostlist);
			rdata_destroy(rd);
			return RET_SUC;
		} else {
			rd = rdata_create((uint8_t *) hostname, strlen(hostname));
			*result_rr = rr_create(rd, TYPE_A, 0, SEC_ANSWER);
			rr_add_rdata(rdata_create((uint8_t *) ipaddr, strlen(ipaddr)), *result_rr);
			hostlist_destroy(hostlist);
			rdata_destroy(rd);
			return RET_SUC;
		}
	}
	hostlist_destroy(hostlist);

	/* TODO: ipv6? */
	if (strlen(hostname) == 9 && strncmp(hostname, "localhost", 9) == 0) {
		rd = rdata_create((uint8_t *) "localhost", 9);
		*result_rr = rr_create(rd, TYPE_A, 0, SEC_ANSWER);
		
		rr_add_rdata(rdata_create((uint8_t *) "127.0.0.1", 9), *result_rr);
		rdata_destroy(rd);
		return RET_SUC;
	}
	
	
	
	if (hostname[strlen(hostname)-1] != '.') {
		for (i=0; i<search->size; i++) {
			queryname = xmalloc(strlen(hostname)+strlen(search->zones[i])+3);
			strncpy(queryname, hostname, strlen(hostname));
			queryname[strlen(hostname)] = '.';
			strncpy(queryname+strlen(hostname)+1, search->zones[i], strlen(search->zones[i])+1);
			if (queryname[strlen(hostname)+strlen(search->zones[i])] != '.') {
				queryname[strlen(hostname)+strlen(search->zones[i])+1] = '.';
				queryname[strlen(hostname)+strlen(search->zones[i])+2] = '\0';
			}
			
			result = resolve2rr2(queryname, ns, trusted_keys, sz, protocol, type, result_rr);
			xfree(queryname);

			if (result == RET_SUC) {
				return result;
			}
		}

		/* none of the search zones worked (if any), try plain . */
		
		queryname = xmalloc(strlen(hostname)+2);
		memset(queryname, 0, strlen(hostname)+2);
		memcpy(queryname, hostname, strlen(hostname));
		/*queryname = xstrdup(hostname);*/
		if (queryname[strlen(hostname)-1] != '.') {
			queryname[strlen(hostname)] = '.';
			queryname[strlen(hostname)+1] = '\0';
		}
		result = resolve2rr2(queryname, ns, trusted_keys, sz, protocol, type, result_rr);
			
		xfree(queryname);

		return result;
	} else {

		result = resolve2rr2(hostname, ns, trusted_keys, sz, protocol, type, result_rr);
	
		return result;
	}
}
