#include "config.h"
#include <sys/types.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <time.h>
#include <dirent.h>
#include <pthread.h>
#include <sys/wait.h>
#include "global.h"
#include "client.h"
#include "command.h"
#include "log.h"

/* Number of calls to remove_wrong_password_list() made while the
 * wrong-password list is non-empty, after which the whole list is flushed.
 */
#define MAX_PASSWORD_DELAY_TIMER 60

/* One record in a client_list bucket (singly linked).
 */
typedef struct type_client {
	t_session *session;
	/* TIMER_OFF while the client is active. Otherwise it is counted down by
	 * check_delayed_remove_timers(), which frees the record and its session
	 * once it reaches 0.
	 */
	int delayed_remove;

	struct type_client *next;
} t_client;

/* One banned IP address in the singly linked banlist.
 */
typedef struct type_banned {
	t_ip_addr ip;
	/* Remaining ban ticks, decremented by check_ban_list(); the entry is
	 * removed at 0. TIMER_OFF keeps the entry without counting down.
	 */
	int       timer;
	/* Ban length as last set by ban_ip(); reban_ip() restores timer to it. */
	int       bantime;
	/* Connection attempts seen from this IP while banned (counted in
	 * connection_allowed(), reported via log_unban()).
	 */
	unsigned long connect_attempts;
	
	struct type_banned *next;
} t_banned;

/* Client records, spread over 256 buckets selected by index_by_ip().
 * Each bucket has its own mutex, so operations on different buckets do not
 * contend with each other.
 */
static t_client *client_list[256];
static pthread_mutex_t client_mutex[256];
/* Active IP bans, protected by ban_mutex. */
static t_banned *banlist;
static pthread_mutex_t ban_mutex;
/* Per-IP wrong-password counters, protected by pwd_mutex. */
static t_ipcounterlist *wrong_password_list;
static pthread_mutex_t pwd_mutex;
/* Tick counter used by remove_wrong_password_list(). */
static int delay_timer = 0;

/* Initialize this module: empty every client bucket, the ban list and the
 * wrong-password list, and initialize the mutex that guards each of them.
 * Must be called before any other function of this module.
 */
void init_client_module(void) {
	int bucket;

	for (bucket = 0; bucket < 256; bucket++) {
		pthread_mutex_init(&client_mutex[bucket], NULL);
		client_list[bucket] = NULL;
	}

	pthread_mutex_init(&ban_mutex, NULL);
	banlist = NULL;

	pthread_mutex_init(&pwd_mutex, NULL);
	wrong_password_list = NULL;
}

/* Add the session record of a client to the client_list.
 * The new record is pushed onto the head of the bucket selected by the
 * session's IP address, with its removal timer switched off.
 * Returns 0 on success, -1 when the record could not be allocated.
 */
int add_client(t_session *session) {
	t_client *record;
	unsigned char bucket;

	record = malloc(sizeof(*record));
	if (record == NULL) {
		return -1;
	}

	record->session = session;
	record->delayed_remove = TIMER_OFF;

	bucket = index_by_ip(&(session->ip_address));

	pthread_mutex_lock(&client_mutex[bucket]);
	record->next = client_list[bucket];
	client_list[bucket] = record;
	pthread_mutex_unlock(&client_mutex[bucket]);

	return 0;
}

/* Keep the client record around for flooding prevention instead of removing
 * it right away. The session's socket is closed and the record's
 * delayed_remove counter is started at delay - 1, to be counted down by
 * check_delayed_remove_timers().
 * Returns 1 when the session was found in the client list, 0 otherwise.
 */
int mark_client_for_removal(t_session *session, int delay) {
	t_client *client;
	unsigned char bucket = index_by_ip(&(session->ip_address));
	int found = 0;

	pthread_mutex_lock(&client_mutex[bucket]);

	for (client = client_list[bucket]; client != NULL; client = client->next) {
		if (client->session != session) {
			continue;
		}

		close_socket(client->session);
		/* NOTE(review): should delay - 1 equal TIMER_OFF, the record would never
		 * be reaped; presumably callers always pass delay >= 1 — confirm.
		 */
		client->delayed_remove = delay - 1;
		found = 1;
		break;
	}

	pthread_mutex_unlock(&client_mutex[bucket]);

	return found;
}

/* Count down the delayed_remove timers of clients marked by
 * mark_client_for_removal(). A client record is freed, together with its
 * session, once its timer has reached 0. Records with TIMER_OFF are left
 * untouched. Does nothing when config->reconnect_delay is not positive.
 */
void check_delayed_remove_timers(t_config *config) {
	t_client *client, *prev, *next;
	int i;

	if (config->reconnect_delay <= 0) {
		return;
	}

	for (i = 0; i < 256; i++) {
		pthread_mutex_lock(&client_mutex[i]);

		/* prev must be reset for every bucket. A pointer left over from an
		 * earlier bucket would make the unlink below modify that bucket,
		 * while client_list[i] kept pointing at the freed record.
		 */
		prev = NULL;
		client = client_list[i];
		while (client != NULL) {
			next = client->next;
			switch (client->delayed_remove) {
				case TIMER_OFF:
					prev = client;
					break;
				case 0:
					/* Unlink first, then free; prev stays where it is. */
					if (prev == NULL) {
						client_list[i] = next;
					} else {
						prev->next = next;
					}
					free(client->session);
					free(client);
					break;
				default:
					client->delayed_remove--;
					prev = client;
					break;
			}
			client = next;
		}

		pthread_mutex_unlock(&client_mutex[i]);
	}
}

/* Remove a client from the client_list.
 * The record holding the given session is unlinked from its bucket and freed.
 * When free_session is true, the session's socket is closed and the session
 * itself is freed as well.
 * Returns 1 when the record was found and removed. Otherwise it logs
 * "Client record not found." and returns 0.
 */
int remove_client(t_session *session, bool free_session) {
	t_client *to_be_removed = NULL, **link;
	unsigned char i;

	i = index_by_ip(&(session->ip_address));

	pthread_mutex_lock(&client_mutex[i]);

	/* Walk the chain of next-pointers so the bucket head needs no special case. */
	for (link = &client_list[i]; *link != NULL; link = &((*link)->next)) {
		if ((*link)->session == session) {
			to_be_removed = *link;
			*link = to_be_removed->next;
			break;
		}
	}

	pthread_mutex_unlock(&client_mutex[i]);

	if (to_be_removed == NULL) {
		/* Reported for any miss (not only for an empty bucket), outside the lock. */
		log_error(session, "Client record not found.");
		return 0;
	}

	if (free_session) {
		close_socket(to_be_removed->session);
		free(to_be_removed->session);
	}
	free(to_be_removed);

	return 1;
}

/* Check whether to allow or deny a new connection from ip.
 * The return value is one of:
 *   ca_BANNED        - ip is on the ban list (its connect_attempts is
 *                      incremented)
 *   ca_TOOMUCH_PERIP - ip already has maxperip or more client records
 *   ca_TOOMUCH_TOTAL - maxtotal or more clients are active
 *   otherwise        - the current number of active clients
 * Records pending delayed removal count towards the per-IP limit, but not
 * towards the total.
 */
int connection_allowed(t_ip_addr *ip, int maxperip, int maxtotal) {
	t_banned *ban;
	t_client *client;
	bool is_banned = false;
	int bucket, same_ip_count = 0, active_count = 0;

	pthread_mutex_lock(&ban_mutex);
	for (ban = banlist; ban != NULL; ban = ban->next) {
		if (same_ip(&(ban->ip), ip)) {
			ban->connect_attempts++;
			is_banned = true;
			break;
		}
	}
	pthread_mutex_unlock(&ban_mutex);

	if (is_banned) {
		return ca_BANNED;
	}

	for (bucket = 0; bucket < 256; bucket++) {
		pthread_mutex_lock(&client_mutex[bucket]);
		for (client = client_list[bucket]; client != NULL; client = client->next) {
			if (same_ip(&(client->session->ip_address), ip)) {
				same_ip_count++;
			}
			if (client->delayed_remove == TIMER_OFF) {
				active_count++;
			}
		}
		pthread_mutex_unlock(&client_mutex[bucket]);
	}

	if (same_ip_count >= maxperip) {
		return ca_TOOMUCH_PERIP;
	}
	if (active_count >= maxtotal) {
		return ca_TOOMUCH_TOTAL;
	}

	return active_count;
}

/* Disconnect all connected clients.
 * Every client record's session gets its force_quit flag set. The function
 * then waits (up to 30 x 100 ms shared across all buckets) for the client
 * buckets to empty. Afterwards it resets nr_of_clients of every configured
 * directory to 0.
 * Returns the number of client records that were flagged.
 */
int disconnect_clients(t_config *config) {
	t_client *client;
	t_directory *dir;
	int max_wait = 30, i, kicked = 0;
	bool bucket_empty;

	for (i = 0; i < 256; i++) {
		pthread_mutex_lock(&client_mutex[i]);

		for (client = client_list[i]; client != NULL; client = client->next) {
			client->session->force_quit = true;
			kicked++;
		}

		pthread_mutex_unlock(&client_mutex[i]);
	}

	/* Each bucket head is read under its mutex, because other threads unlink
	 * records (remove_client, check_delayed_remove_timers) while we wait.
	 */
	for (i = 0; i < 256; i++) {
		while (max_wait > 0) {
			pthread_mutex_lock(&client_mutex[i]);
			bucket_empty = (client_list[i] == NULL);
			pthread_mutex_unlock(&client_mutex[i]);

			if (bucket_empty) {
				break;
			}
			max_wait--;
			usleep(100000);
		}
	}

	for (dir = config->directory; dir != NULL; dir = dir->next) {
		dir->nr_of_clients = 0;
	}

	return kicked;
}

/* Kick an IP address: flag every client connected from ip to quit.
 * Only bucket index_by_ip(ip) is scanned, since add_client() files each
 * record under the bucket of its own IP address (presumably index_by_ip is
 * consistent with same_ip).
 * Returns the number of clients flagged.
 */
int kick_ip(t_ip_addr *ip) {
	unsigned char bucket = index_by_ip(ip);
	t_client *client;
	int kicked = 0;

	pthread_mutex_lock(&client_mutex[bucket]);

	for (client = client_list[bucket]; client != NULL; client = client->next) {
		if (!same_ip(&(client->session->ip_address), ip)) {
			continue;
		}
		client->session->force_quit = true;
		kicked++;
	}

	pthread_mutex_unlock(&client_mutex[bucket]);

	return kicked;
}

/* Check if the client is flooding the server with requests.
 * The rate kept_alive per second, measured over the time since
 * session->flooding_timer plus one second, is compared against the
 * configured rate of flooding_count per flooding_time seconds.
 * Returns true when the client's rate is higher.
 */
bool client_is_flooding(t_session *session) {
	time_t elapsed;

	elapsed = time(NULL) - session->flooding_timer + 1;

	/* kept_alive / elapsed > flooding_count / flooding_time, cross-multiplied
	 * to avoid division (assumes both denominators are positive).
	 */
	return (session->config->flooding_count * elapsed) < (session->kept_alive * session->config->flooding_time);
}

/* Disconnect a client: flag the client whose session has the given
 * client_id to quit. Every bucket is searched; within a bucket only the
 * first match is flagged.
 * Returns 1 when a matching client was found, 0 otherwise.
 */
int kick_client(int id) {
	t_client *client;
	int bucket, found = 0;

	for (bucket = 0; bucket < 256; bucket++) {
		pthread_mutex_lock(&client_mutex[bucket]);

		client = client_list[bucket];
		while ((client != NULL) && (client->session->client_id != id)) {
			client = client->next;
		}
		if (client != NULL) {
			client->session->force_quit = true;
			found = 1;
		}

		pthread_mutex_unlock(&client_mutex[bucket]);
	}

	return found;
}

/* IP ban functions
 */

/* Ban an IP address for timer ticks (counted down by check_ban_list()).
 * If ip is already banned, only its timer and bantime are reset and 0 is
 * returned. Otherwise a new entry is added and 1 is returned, or -1 when the
 * entry could not be allocated.
 * When kick_on_ban is set and ip was not banned before, the clients from ip
 * are flagged to quit (best effort, even if allocation failed). On success
 * the number of kicked clients is returned instead of 1; an allocation
 * failure is still reported as -1.
 */
int ban_ip(t_ip_addr *ip, int timer, bool kick_on_ban) {
	int retval = 0, kicked;
	t_banned *ban;
	bool new_ip = true;

	pthread_mutex_lock(&ban_mutex);

	ban = banlist;
	while (ban != NULL) {
		if (same_ip(&(ban->ip), ip)) {
			/* Already banned: restart the ban with the new length. */
			ban->timer = timer;
			ban->bantime = timer;
			new_ip = false;
			break;
		}
		ban = ban->next;
	}

	if (new_ip) {
		if ((ban = (t_banned*)malloc(sizeof(t_banned))) != NULL) {
			copy_ip(&(ban->ip), ip);
			ban->timer = timer;
			ban->bantime = timer;
			ban->connect_attempts = 0;
			ban->next = banlist;
			banlist = ban;

#ifdef HAVE_COMMAND
			increment_counter(COUNTER_BAN);
#endif

			retval = 1;
		} else {
			retval = -1;
		}
	}

	pthread_mutex_unlock(&ban_mutex);

	if (kick_on_ban && new_ip) {
		kicked = kick_ip(ip);
		/* Do not let the kick count hide an allocation failure. */
		if (retval != -1) {
			retval = kicked;
		}
	}

	return retval;
}

/* Reset the timer of a banned IP address to its bantime.
 * Returns 1 when ip was found in the ban list, 0 otherwise.
 */
int reban_ip(t_ip_addr *ip) {
	t_banned *ban;
	int found = 0;

	pthread_mutex_lock(&ban_mutex);

	for (ban = banlist; (ban != NULL) && !found; ban = ban->next) {
		if (same_ip(&(ban->ip), ip)) {
			ban->timer = ban->bantime;
			found = 1;
		}
	}

	pthread_mutex_unlock(&ban_mutex);

	return found;
}

/* Check the timers of the banlist once.
 * Entries with TIMER_OFF are left untouched. Entries whose timer has reached
 * 0 are unlinked, reported through log_unban() and freed. All other timers
 * are decremented.
 */
void check_ban_list(t_config *config) {
	t_banned **link, *ban;

	pthread_mutex_lock(&ban_mutex);

	link = &banlist;
	while ((ban = *link) != NULL) {
		if (ban->timer == TIMER_OFF) {
			link = &(ban->next);
		} else if (ban->timer == 0) {
			*link = ban->next;
			log_unban(config->system_logfile, &(ban->ip), ban->connect_attempts);
			free(ban);
		} else {
			ban->timer--;
			link = &(ban->next);
		}
	}

	pthread_mutex_unlock(&ban_mutex);
}

/* Unban an IP address.
 * If ip equals the address produced by default_ipv4() (or, with HAVE_IPV6,
 * default_ipv6()), every ban is removed. Otherwise only the first matching
 * entry is removed.
 * Returns the number of ban entries removed.
 */
int unban_ip(t_ip_addr *ip) {
	t_ip_addr any;
	t_banned **link, *ban;
	bool unban_all;
	int removed = 0;

	default_ipv4(&any);
	unban_all = same_ip(ip, &any);
#ifdef HAVE_IPV6
	if (unban_all == false) {
		default_ipv6(&any);
		unban_all = same_ip(ip, &any);
	}
#endif

	pthread_mutex_lock(&ban_mutex);

	link = &banlist;
	while ((ban = *link) != NULL) {
		if (!same_ip(&(ban->ip), ip) && !unban_all) {
			link = &(ban->next);
			continue;
		}

		*link = ban->next;
		free(ban);
		removed++;

		if (!unban_all) {
			break;
		}
	}

	pthread_mutex_unlock(&ban_mutex);

	return removed;
}

/* Record a failed password attempt from the session's IP address.
 * Does nothing when ban_on_wrong_password is 0. On the first failure a new
 * counter is created for the address. Once the count reaches
 * max_wrong_passwords, and ip_allowed() returns non-zero for banlist_mask,
 * the following happens:
 *   - the address is banned for ban_on_wrong_password ticks,
 *   - the session's keep_alive is cleared,
 *   - the ban is logged.
 * Returns 0, or -1 when a new counter could not be allocated.
 */
int register_wrong_password(t_session *session) {
	t_ipcounterlist *counter;
	int retval = 0;

	if (session->config->ban_on_wrong_password == 0) {
		return 0;
	}

	pthread_mutex_lock(&pwd_mutex);

	for (counter = wrong_password_list; counter != NULL; counter = counter->next) {
		if (same_ip(&(counter->ip), &(session->ip_address))) {
			break;
		}
	}

	if (counter != NULL) {
		counter->count++;
		if ((counter->count >= session->config->max_wrong_passwords) &&
		    ip_allowed(&(session->ip_address), session->config->banlist_mask)) {
			ban_ip(&(session->ip_address), session->config->ban_on_wrong_password, session->config->kick_on_ban);
			session->keep_alive = false;
			log_system(session, "Client banned because too many wrong passwords");
		}
	} else if ((counter = malloc(sizeof(*counter))) != NULL) {
		copy_ip(&(counter->ip), &(session->ip_address));
		counter->count = 1;
		counter->next = wrong_password_list;
		wrong_password_list = counter;
	} else {
		retval = -1;
	}

	pthread_mutex_unlock(&pwd_mutex);

	return retval;
}

/* Flush all wrong-password counters.
 * The list is emptied only after this function has been called
 * MAX_PASSWORD_DELAY_TIMER times while the list was non-empty (presumably it
 * is called once per timer tick). Does nothing when ban_on_wrong_password
 * is 0.
 */
void remove_wrong_password_list(t_config *config) {
	t_ipcounterlist *item, *victim;

	if (config->ban_on_wrong_password == 0) {
		return;
	}

	/* The list head and delay_timer are inspected under pwd_mutex, because
	 * register_wrong_password() modifies the list concurrently.
	 */
	pthread_mutex_lock(&pwd_mutex);

	if ((wrong_password_list == NULL) || (++delay_timer < MAX_PASSWORD_DELAY_TIMER)) {
		pthread_mutex_unlock(&pwd_mutex);
		return;
	}

	item = wrong_password_list;
	wrong_password_list = NULL;
	delay_timer = 0;

	pthread_mutex_unlock(&pwd_mutex);

	/* The detached list is private now; free it outside the lock. */
	while (item != NULL) {
		victim = item;
		item = item->next;
		free(victim);
	}
}

#ifdef HAVE_COMMAND
/* Print the list of current connections to fp.
 * For each client record it prints the client ID, the status (DEBUG builds
 * only), the IP address, the socket, the remote user (if any) and the
 * kept-alive count. A total line follows the list.
 */
void print_client_list(FILE *fp) {
	char ip_address[MAX_IP_STR_LEN];
	t_client *client;
	int bucket, total = 0;

	for (bucket = 0; bucket < 256; bucket++) {
		pthread_mutex_lock(&client_mutex[bucket]);

		for (client = client_list[bucket]; client != NULL; client = client->next) {
			total++;

			fprintf(fp, "  Client ID   : %d\n", client->session->client_id);
#ifdef DEBUG
			fprintf(fp, "  Status      : %s\n", (client->session->status != NULL) ? client->session->status : "unknown");
#endif
			if (inet_ntop(client->session->ip_address.family, &(client->session->ip_address.value), ip_address, MAX_IP_STR_LEN) != NULL) {
				fprintf(fp, "  IP-address  : %s\n", ip_address);
			}
			fprintf(fp, "  Socket      : %d\n", client->session->client_socket);
			if (client->session->remote_user != NULL) {
				fprintf(fp, "  Remote user : %s\n", client->session->remote_user);
			}
			fprintf(fp, "  Kept alive  : %d\n\n", client->session->kept_alive);
		}

		pthread_mutex_unlock(&client_mutex[bucket]);
	}

	fprintf(fp, "  Total: %d clients\n", total);
}

/* Count all client records in the client list, including records that are
 * pending delayed removal.
 */
int number_of_clients(void) {
	t_client *client;
	int bucket, count = 0;

	for (bucket = 0; bucket < 256; bucket++) {
		pthread_mutex_lock(&client_mutex[bucket]);
		for (client = client_list[bucket]; client != NULL; client = client->next) {
			count++;
		}
		pthread_mutex_unlock(&client_mutex[bucket]);
	}

	return count;
}

/* Print the list of banned IP addresses to fp.
 * Each entry shows its IP address and remaining ticks. A total line follows,
 * printed while the ban list is still locked.
 */
void print_ban_list(FILE *fp) {
	char ip_address[MAX_IP_STR_LEN];
	t_banned *ban;
	int total = 0;

	pthread_mutex_lock(&ban_mutex);

	for (ban = banlist; ban != NULL; ban = ban->next) {
		total++;
		if (inet_ntop(ban->ip.family, &(ban->ip.value), ip_address, MAX_IP_STR_LEN) != NULL) {
			fprintf(fp, "  IP-address  : %s\n", ip_address);
		}
		fprintf(fp, "  seconds left: %d\n\n", ban->timer);
	}
	fprintf(fp, "  Total: %d bans\n", total);

	pthread_mutex_unlock(&ban_mutex);
}

/* Count the entries in the ban list.
 */
int number_of_bans(void) {
	t_banned *ban;
	int count = 0;

	pthread_mutex_lock(&ban_mutex);
	for (ban = banlist; ban != NULL; ban = ban->next) {
		count++;
	}
	pthread_mutex_unlock(&ban_mutex);

	return count;
}
#endif
