/*
* ip_util.c  (was ipv4_util.c)

Copyright (C) 2008-2023 Alessandro Vesely

This file is part of Ipqbdb.

Ipqbdb is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Ipqbdb is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Ipqbdb.  If not, see <http://www.gnu.org/licenses/>.

*/

#include "ip_util.h"
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <limits.h>
#include <ctype.h>
#include <endian.h>
#include <arpa/inet.h> // for htonl (or is it <netinet/in.h>?)

#include <assert.h>

static int check_ipv4_c_or_m(unsigned char ip[16])
/*
* Classify a 16-byte address held in network byte order.
*
* Return 6 unless the first 10 bytes are zero and bytes 10 and 11 are
* either both 0 (IPv4-compatible) or both 0xff (IPv4-mapped, RFC 4291).
* Return 0 for the all-zero address.  Otherwise copy the last four bytes
* to the front of ip, so that it can be stored in the IPv4 database,
* and return 4.  Bytes 4 to 15 are left as they were.
*/
{
	static unsigned char const zero_prefix[10] = {0};

	if (memcmp(ip, zero_prefix, sizeof zero_prefix) != 0)
		return 6;

	unsigned char const marker = ip[10];
	if (marker != ip[11])
		return 6;
	if (marker != 0 && marker != 0xffU)
		return 6;

	int const low_is_zero = (ip[12] | ip[13] | ip[14] | ip[15]) == 0;
	if (marker == 0 && low_is_zero)
		return 0;

	memmove(ip, ip + 12, 4);
	return 4;
}

int my_inet_pton(char const *p, ip_u *u)
/*
* Convert the textual address p to binary, guessing the address family:
* a string with a colon is handed to inet_pton as AF_INET6, any other
* string as AF_INET.
*
* u is zeroed first.  On success u->ip is set to 4 for an IPv4 address,
* or to the result of check_ipv4_c_or_m for an IPv6 one.
*
* Like inet_pton, return 1 on success, 0 or -1 on error.
*/
{
	assert(p);
	assert(u);

	memset(u, 0, sizeof *u);

	int const has_colon = strchr(p, ':') != NULL;
	if (has_colon)
	{
		int const rc6 = inet_pton(AF_INET6, p, u->u.ipv6);
		if (rc6 > 0)
			u->ip = check_ipv4_c_or_m(u->u.ipv6);
		return rc6;
	}

	int const rc4 = inet_pton(AF_INET, p, u->u.ipv4);
	if (rc4 == 1)
		u->ip = 4;
	return rc4;
}

char const *my_inet_ntop(ip_u *ip_addr, char addr[INET6_ADDRSTRLEN])
/*
* Format ip_addr into addr according to its ip field (4 or 6) and return
* whatever inet_ntop returns.  For any other value of the ip field addr
* is not written and a fixed "invalid address" string is returned.
*/
{
	assert(ip_addr);

	switch (ip_addr->ip)
	{
		case 4:
			return inet_ntop(AF_INET, ip_addr->u.ipv4,
				addr, INET6_ADDRSTRLEN);

		case 6:
			return inet_ntop(AF_INET6, ip_addr->u.ipv6,
				addr, INET6_ADDRSTRLEN);

		default:
			return "invalid address";
	}
}

void first_in_range(unsigned char ip[16], int plen)
/*
* Turn ip into the first address of its /plen network, by clearing
* the 128 - plen trailing bits in place.  plen must be in 0..128.
*/
{
	assert(plen >= 0);
	assert(plen <= 128);

	// clear whole bytes from the end, as long as more than 8 bits remain
	int host_bits = 128 - plen;
	unsigned char *byte = &ip[15];
	for (; host_bits > 8; host_bits -= 8)
		*byte-- = 0;

	// 0 to 8 bits are left to clear in the current byte
	unsigned char const keep = ~((1 << host_bits) - 1);
	*byte &= keep;
}


char const *parse_ip_invalid_what(int err)
/*
* Map an error code returned by parse_ip_address to a static string
* naming the offending item.  Unknown codes yield "internal error".
*/
{
	char const *what = "internal error";

	switch (err)
	{
		case parse_ip_ok:
			what = "no error";
			break;
		case parse_ip_invalid_too_high:
			what = "IP address with high octet";
			break;
		case parse_ip_invalid_too_long:
			what = "IP address with many octets";
			break;
		case parse_ip_invalid_separator:
			what = "separator after IP address";
			break;
		case parse_ip_invalid_range:
			what = "IP range";
			break;
		case parse_ip_invalid_cidr:
			what = "IP CIDR";
			break;
		case parse_ip_invalid_ipv6_1:
			what = "Invalid IPv6 address";
			break;
		case parse_ip_invalid_ipv6_2:
			what = "Invalid IPv6 end range";
			break;
		default:
			break;
	}

	return what;
}

static int parse_ip4_address(char *ip_address, ip_range *ip, char **term)
/*
* Parse a dotted decimal IPv4 address, a range "first-last", or a CIDR
* block "address/plen".  Trailing octets that are not given are taken
* as 0, so "1.2" is 1.2.0.0.
*
* ip_address: the string to parse; it is not modified.
* ip:         receives the first address in u.ipv4l and the last one in
*             u2.ipv4l (network byte order), the number of arguments in
*             args and, for a CIDR block, the prefix length in plen.
*             It is untouched on error.
* term:       if not NULL, it is set to the first character after the
*             parsed text.  If NULL, the parsed text must extend to the
*             end of the string.
*
* For a CIDR block with prefix length 1 to 31 the host bits are cleared
* in the first address and set in the last one.  A prefix length of 0,
* including a '/' followed by no digits, is stored with args = 1 and a
* last address of all ones.
*
* Return 0 on success or one of the parse_ip_invalid_* codes.
*/
{
	int i, sep, mode_flag = 0;
	union bytewise
	{
		uint32_t ip;
		unsigned char b[4];
	} uf, ul, *un = &uf;
	char *p = ip_address;

	memset(&ul, 0, sizeof ul);
	memset(&uf, 0, sizeof uf);

	/*
	* parse any of "1", "1.2", "1.2.3.4", "1.2.3-1.2.3.7", "1.2.3.0/29", etc.
	* store first ip in uf, other stuff in ul
	*/
	for (i = 0;;)
	{
		char *t = NULL;
		unsigned long l = strtoul(p, &t, 10);

		if (l > UCHAR_MAX)
			return parse_ip_invalid_too_high;

		un->b[i] = (unsigned char)l;
		p = t;
		sep = *p;
		if (sep == '.' && mode_flag != '/')
		{
			if (++i >= 4)
				return parse_ip_invalid_too_long;
		}
		else if (mode_flag == 0 && (sep == '-' || sep == '/'))
		{
			// what follows is the last address or the prefix length
			mode_flag = sep;
			i = 0;
			un = &ul;
		}
		else
			break;

		p += 1;
	}

	/*
	* if the terminator is taken, set it to p later,
	* otherwise if the string is not 0-terminated fail
	*/
	if (sep != 0 && term == NULL)
		return parse_ip_invalid_separator;

	/*
	* set first and last ip according to parsed values
	* (ip is untouched on error)
	*/

	if (mode_flag == '-')
	{
		// bytes are in network order, so memcmp compares addresses
		if (memcmp(&uf, &ul, sizeof uf) > 0)
			return parse_ip_invalid_range;

		ip->u2.ipv4l = ul.ip;
		ip->args = 2;
	}
	else if (mode_flag == '/')
	{
		if (ul.b[0] > 32)
			return parse_ip_invalid_cidr;

		if (ul.b[0])
		{
			ip->u2.ipv4l = uf.ip;
			if (ul.b[0] < 32)
			{
				uint32_t mask = htonl((1U << (32 - ul.b[0])) - 1);
				ip->u2.ipv4l |= mask;

				/*
				* Clear the host bits of the first address (new for
				* version 2).  This has to be done on uf, because
				* ip->u.ipv4l is assigned from uf below.
				*/
				uf.ip &= ~mask;
			}
			ip->args = 2;
			ip->plen = ul.b[0];
		}
		else
		{
			ip->u2.ipv4l = -1;
			ip->args = 1;
		}
	}
	else
	{
		ip->u2.ipv4l = uf.ip;
		ip->args = 1;
	}

	ip->u.ipv4l = uf.ip;

	assert(sep == *p);
	if (term)
		*term = p;

	return 0;
}

int parse_ip_address(char *ip_address, ip_range *ip, char **term)
/*
* Parse an IP address or range.  Return 0 on success or error_code.
*
* ip_address: the string to parse.  A string without colons is passed
*             to parse_ip4_address.  Otherwise it is parsed as an IPv6
*             address, optionally followed by "/plen" or "-last".  The
*             string must be writable: a 0 is temporarily stored at the
*             end of each address while calling inet_pton, and the
*             original character is restored afterwards.
* ip:         zeroed on entry, then filled with the first address (u),
*             the last one (u2), args and plen.  On an IPv6 error it
*             may be left partially filled.
* term:       if not NULL, it is set to the first character after the
*             parsed text.  If NULL, the parsed text must extend to the
*             end of the string.
*
* Both ends of an IPv6 range are passed to check_ipv4_c_or_m; if the
* results differ, ip->ip is set to 0 and parse_ip_invalid_range is
* returned.  If they are both 4, ip->ipv4_mapped is set.
*
* NOTE(review): unlike parse_ip4_address, the order of the two ends of
* a "first-last" range is not checked here.
*/
{
	assert(ip);
	memset(ip, 0, sizeof *ip);

	if (strchr(ip_address, ':') == NULL)
	{
		ip->ip = 4;
		return parse_ip4_address(ip_address, ip, term);
	}

	// strspn the address part to check for range
	char *e = ip_address;
	int ch;
	for (ch = *(unsigned char*)e; ch; ch = *(unsigned char*)++e)
		if (ch != ':' && !isxdigit(ch) && ch != '.')
			break;

	*e = 0;
	int rtc = inet_pton(AF_INET6, ip_address, &ip->u.ipv6[0]);
	*e = ch;
	if (rtc != 1)
		return parse_ip_invalid_ipv6_1;

	ip->ip = 6;
	ip->args = 1;
	if (ch == '/')
	{
		char *s = ++e;
		unsigned long l = strtoul(s, &e, 10);
		if (l > 128)
			return parse_ip_invalid_cidr;

		ip->args = 2;
		ip->plen = l;
		ch = *e;
		memcpy(ip->u2.ipv6, ip->u.ipv6, sizeof ip->u2.ipv6);
		if (l < 128)
		{
			// i is the 64-bit half where the prefix ends
			int i;
			if (l > 64)
			{
				i = 1;
				l -= 64;
			}
			else
			{
				// the whole low half is host part
				i = 0;
				ip->u2.ipv6l[1] = -1;
				ip->u.ipv6l[1] = 0;
			}

			uint64_t mask;
			if (l > 0)
			{
				int shift = 64 - l;
				assert(shift >= 0 && shift < 64);

				/*
				* Unsigned arithmetic: shift can be 63, which is not
				* representable when shifting a signed 1.
				*/
				mask = htobe64(((uint64_t)1 << shift) - 1);
			}
			else
				mask = -1;
			ip->u2.ipv6l[i] |= mask;
			ip->u.ipv6l[i] &= ~mask;
		}
	}
	else if (ch == '-')
	{
		/*
		* Accept the same characters as for the first address, so that
		* the dotted form of IPv4-mapped addresses works for both ends.
		*/
		char *s = ++e;
		for (ch = *(unsigned char*)e; ch; ch = *(unsigned char*)++e)
			if (ch != ':' && !isxdigit(ch) && ch != '.')
				break;
		*e = 0;
		rtc = inet_pton(AF_INET6, s, ip->u2.ipv6);
		*e = ch;
		if (rtc != 1)
			return parse_ip_invalid_ipv6_2;

		ip->args = 2;
	}
	else
		memcpy(ip->u2.ipv6, ip->u.ipv6, sizeof ip->u2.ipv6);

	if (ch != 0 && term == NULL)
		return parse_ip_invalid_separator;

	if ((ip->ip = check_ipv4_c_or_m(ip->u.ipv6)) == 4)
		ip->ipv4_mapped = 1;
	if (ip->ip != check_ipv4_c_or_m(ip->u2.ipv6))
	{
		ip->ip = 0;
		return parse_ip_invalid_range;
	}

	if (term)
		*term = e;

	return 0;
}

int range_ip(unsigned char const *ip1, unsigned char *ip2, unsigned sz)
/*
* Assuming that ip1 and ip2 are in the same network, prepare ip2 by
* filling the host parts with 1's and return the prefix length.
*
* ip1, ip2: addresses of sz bytes each, in network byte order.
* sz:       size of the addresses in bytes (4 for IPv4, 16 for IPv6).
*
* The prefix is the run of leading bits that ip1 and ip2 have in
* common.  If the two addresses are equal, ip2 is left unchanged and
* 8*sz is returned.  No byte at index sz or beyond is read or written.
*/
{
	assert(ip1);
	assert(ip2);

	// find the first byte that differs
	unsigned ndx;
	for (ndx = 0; ndx < sz; ++ndx)
		if (ip1[ndx] != ip2[ndx])
			break;

	if (ndx >= sz) // no difference at all
		return 8*sz;

	unsigned char x = ip1[ndx] ^ ip2[ndx];

	int n = 1; // count leading 0 bits in x
	if ((x >> 4) == 0) { n += 4; x <<= 4; }
	if ((x >> 6) == 0) { n += 2; x <<= 2; }
	n -= x >> 7;

	int plen = ndx * 8 + n;

	int bit = 1 << (7 - n);
	ip2[ndx] |= bit | (bit - 1); // right propagate bit

	// the remaining bytes, if any, are all host part
	if (++ndx < sz)
		memset(&ip2[ndx], 0xff, sz - ndx);

	return plen;
}

static int
prefix_length(unsigned char const *first, unsigned char const *last, unsigned sz)
/*
* Tell whether first..last can be written as a CIDR block.  The size sz
* of the addresses is at most 16.  Return the prefix length if first
* and last are the boundaries of a CIDR range, 0 otherwise.
*/
{
	assert(first);
	assert(last);

	// last must already have all of its host bits set
	unsigned char filled[sz];
	memcpy(filled, last, sz);
	int const plen = range_ip(first, filled, sz);
	if (memcmp(filled, last, sz) != 0)
		return 0;

	// first must have all of its host bits clear: whole bytes first...
	int host_bits = 8*sz - plen;
	unsigned char const *byte = first + sz - 1;
	for (; host_bits > 8; host_bits -= 8, --byte)
		if (*byte != 0)
			return 0;

	// ...then the 0 to 8 bits left in the current byte
	unsigned char const low = (1 << host_bits) - 1;
	return (*byte & low)? 0: plen;
}

char const *snprint_range(char *buf, size_t bufsiz, ip_range const *ip)
/*
* Write a textual representation of ip into buf and return buf.
*
* buf, bufsiz: destination buffer and its size.  If bufsiz is 0 nothing
*              is written; otherwise buf always holds a 0-terminated
*              string on return.
* ip:          the range to print; if ip->args is not 0, ip->ip must be
*              4 or 6.
*
* The result is "any" if ip->args is 0, else the first address.  If
* ip->args is greater than 1, that is followed by "/plen" when the two
* addresses are the boundaries of a CIDR block, or by "-" and the last
* address otherwise.  An address is only written if the space left is
* at least INET_ADDRSTRLEN or INET6_ADDRSTRLEN respectively; otherwise
* it is omitted.
*/
{
	char *entrybuf = buf;
	if (bufsiz == 0)
		return entrybuf;

	// whatever is skipped below, the caller gets a valid string
	buf[0] = 0;

	if (ip->args == 0)
		snprintf(buf, bufsiz, "any");
	else
	{
		assert(ip->args >= 1);
		assert(ip->ip == 4 || ip->ip == 6);

		if (ip->ip == 4 && bufsiz >= INET_ADDRSTRLEN)
			inet_ntop(AF_INET, ip->u.ipv4, buf, bufsiz);
		else if (ip->ip == 6 && bufsiz >= INET6_ADDRSTRLEN)
			inet_ntop(AF_INET6, ip->u.ipv6, buf, bufsiz);

		if (ip->args > 1)
		{
			int plen = prefix_length(ip->u.ip_data,
				ip->u2.ip_data, ip->ip == 4? 4: 16);
			size_t at = strlen(buf);
			if (bufsiz > at + 2)
			{
				buf[at++] = plen? '/': '-';

				// the separator replaced the terminator: put it back,
				// in case the last address does not fit
				buf[at] = 0;

				bufsiz -= at;
				buf += at;
				if (plen)
					snprintf(buf, bufsiz, "%d", plen);
				else if (ip->ip == 4 && bufsiz >= INET_ADDRSTRLEN)
					inet_ntop(AF_INET, &ip->u2.ipv4, buf, bufsiz);
				else if (ip->ip == 6 && bufsiz >= INET6_ADDRSTRLEN)
					inet_ntop(AF_INET6, ip->u2.ipv6, buf, bufsiz);
			}
		}
	}
	return entrybuf;
}


#if defined TEST_MAIN2
int main(int argc, char *argv[])
/*
* Test driver for my_inet_pton and first_in_range.  Each argument is
* converted and the resulting structure is dumped in hex.  For results
* marked as IPv6, the first address of the enclosing network is printed
* for prefix lengths from 124 down to 4, in steps of 4, whenever it
* differs from the one obtained at the previous step.
*/
{
	union
	{
		ip_u u;
		unsigned char ch[sizeof(ip_u)];
	} var;

	printf("size of ip_u = %zu\n", sizeof var.u);
	for (int i = 1; i < argc; ++i)
	{
		int rtc = my_inet_pton(argv[i], &var.u);

		// hex dump of the whole structure, two digits per byte
		char buf[2*sizeof var + 1];
		char *out = buf;
		for (unsigned j = 0; j < sizeof var; ++j, out += 2)
			sprintf(out, "%02x", var.ch[j]);
		*out = 0;
		printf("my_inet_pton(%s, %s) = %d\n", argv[i], buf, rtc);

		if (var.u.ip != 6)
			continue;

		unsigned char last[16];
		memcpy(last, var.u.u.ipv6, sizeof last);

		int plen = 124;
		while (plen > 0)
		{
			unsigned char rng[16];
			memcpy(rng, var.u.u.ipv6, sizeof rng);
			first_in_range(rng, plen);
			if (memcmp(last, rng, sizeof rng) != 0)
			{
				char addr[INET6_ADDRSTRLEN];
				printf("at %d: %s/%d\n", plen,
					inet_ntop(AF_INET6, rng, addr, sizeof addr), plen);
				memcpy(last, rng, sizeof rng);
			}
			plen -= 4;
		}
	}

	return 0;
}
#endif // TEST_MAIN2

#if defined TEST_MAIN

int main(int argc, char *argv[])
/*
* Test driver for parse_ip_address: parse each argument, requiring it to
* be consumed entirely, then print it back with snprint_range along with
* the description of the result code.
*/
{
	printf("sizeof ip_range %zu, %d arg(s)\n", sizeof(ip_range), argc - 1);

	char buf[INET_RANGESTRLEN];
	int i = 1;
	while (i < argc)
	{
		ip_range ip;
		int const err = parse_ip_address(argv[i], &ip, NULL);
		char const *range = snprint_range(buf, sizeof buf, &ip);
		char const *what = parse_ip_invalid_what(err);

		printf("%d) %s -> %s (%s)\n", i, argv[i], range, what);
		++i;
	}

	return 0;
}
#endif //defined TEST_MAIN
