/* Schedwi
   Copyright (C) 2007 Herve Quatremain

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU Library General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
   USA.
*/

/* schedwiloadcert.c -- Add or update an agent host in the database */

#include <schedwi.h>

#if HAVE_SYS_TYPES_H
#include <sys/types.h>
#endif

#if STDC_HEADERS
#include <stdlib.h>
#include <string.h>
#else
#if HAVE_STDLIB_H
#include <stdlib.h>
#endif
#if HAVE_STRING_H
#include <string.h>
#endif
#endif

#if HAVE_STDIO_H
#include <stdio.h>
#endif

#if HAVE_SYS_STAT_H
#include <sys/stat.h>
#endif

#if HAVE_UNISTD_H
#include <unistd.h>
#endif

#if HAVE_TIME_H
#include <time.h>
#endif

#if HAVE_GETOPT_H
#include <getopt.h>
#endif

#ifdef HAVE_LOCALE_H
#include <locale.h>
#endif

#if HAVE_ERRNO_H
#include <errno.h>
#endif
#ifndef errno
extern int errno;
#endif

#if HAVE_ASSERT_H
#include <assert.h>
#endif

#include <gnutls/gnutls.h>
#include <gnutls/x509.h>

#include <sql_common.h>
#include <utils.h>
#include <conf.h>
#include <conf_srv.h>
#include <net_utils.h>
#include <net_utils_sock.h>
#include <net_utils_ssl.h>
#include <sql_hosts.h>
#include <lib_functions.h>

/*
 * Path of the server configuration file.  Set by main() to
 * SCHEDWI_DEFAULT_CONFFILE_SRV or to the -c/--config option argument,
 * printed by help() and loaded with conf_init_srv().
 */
static const char *configuration_file = NULL;


/*
 * Print a help message to stderr
 *
 * prog_name is the program name displayed in the usage line.  In the
 * message table below, an empty string prints an empty line (it is not
 * passed to gettext, which would return the catalog header for ""), and
 * the "~" entry is a placeholder replaced by the path of the default
 * configuration file (configuration_file).
 */
void
help (const char * prog_name)
{
	int i;
	/* Read-only table of string literals: both levels are const */
	const char * const msg[] = {
N_("Add or update a Schedwi agent certificate in the database."),
N_("This program must be run on the same host as the Schedwi server."),
"",
#if HAVE_GETOPT_LONG
N_("  -c, --config=FILE  use the configuration file FILE rather than"),
N_("                     the default one"),
N_("  -a, --agent=NAME   agent host name.  If not set, the host name in the"),
N_("                     certificate (in the Alternate Name or in the"),
N_("                     Common Name field) will be used"),
N_("  -p, --port=PORT    agent TCP port"),
N_("  -d, --descr=TEXT   also add an agent description in the database"),
N_("  -f, --force        force the database upload even if there is a"),
N_("                     mismatch between the provided host name and the one"),
N_("                     in the certificate file (in the Alternate Name or"),
N_("                     in the Common Name field)"),
N_("  -h, --help         display this help and exit"),
N_("  -V, --version      print version information and exit"),
#else /* HAVE_GETOPT_LONG */
N_("  -c FILE            use the configuration file FILE rather than"),
N_("                     the default one"),
N_("  -a HOSTNAME        agent host name.  If not set, the host name in the"),
N_("                     certificate (in the Alternate Name or in the"),
N_("                     Common Name field) will be used"),
N_("  -p PORT            agent TCP port"),
N_("  -d DESCRIPTION     also add an agent description in the database"),
N_("  -f                 force the database upload even if there is a"),
N_("                     mismatch between the provided host name and the one"),
N_("                     in the certificate file (in the Alternate Name or"),
N_("                     in the Common Name field)"),
N_("  -h                 display this help and exit"),
N_("  -V                 print version information and exit"),
#endif /* ! HAVE_GETOPT_LONG */
"",
"~",
N_("Exit status is 1 on error (0 means no error)."),
NULL
	};

#if HAVE_ASSERT_H
	assert (prog_name != NULL && configuration_file != NULL);
#endif

	fprintf (stderr, _("Usage: %s [OPTION]... FILE\n"), prog_name);

	for (i = 0; msg[i] != NULL; i++) {
		if (msg[i][0] == '~') {
			/* Placeholder for the default configuration file */
			fputs (	_("The default configuration file is "),
				stderr);
			fputs (configuration_file, stderr);
		}
		else {
			/* Do not translate "" (gettext returns the header) */
			fputs ((msg[i][0] == '\0')? msg[i]: _(msg[i]), stderr);
		}
		fputc ('\n', stderr);
	}

	fputc ('\n', stderr);
	fputs (_("Report bugs to "), stderr);
	fputs (PACKAGE_BUGREPORT, stderr);
	fputc ('\n', stderr);
}


/*
 * Retrieve the certificate and the agent host name from the certificate file
 *
 * The whole file is read and decoded as an X.509 certificate, first in the
 * PEM and then in the DER format.  The certificate is rejected if it is not
 * yet activated or if it has already expired, and a warning is printed if
 * it expires in less than 30 days.  The agent host name is the first DNS
 * entry of the Subject Alternative Name extension or, if there is none,
 * the Common Name of the certificate subject.
 *
 * Parameters:
 *   filename    --> certificate file name
 *   cert        --> set to the whole ('\0' terminated) file content
 *   cert_len    --> if not NULL, set to the length of the file content
 *   client_name --> set to the agent host name found in the certificate
 *
 * Return:
 *   0 --> No error.  cert and client_name are set and must be freed by the
 *         caller by free().  If not NULL cert_len is set with the len
 *         of the certificate (in cert)
 *  -1 --> Error.  An error message has been printed on stderr
 */
static int
read_certificat (	const char *filename,
			char **cert, unsigned long int *cert_len,
			char **client_name)
{
	struct stat stat_str;
	FILE *f;
	char *file_content, *agent_name;
	size_t file_size;
	time_t now, t;
	int ret;
	unsigned int i;
	char dnsname[512];
	size_t dnsnamesize;
	char found_dnsname;
	gnutls_x509_crt_t cert_obj;
	gnutls_datum_t cert_data;

#if HAVE_ASSERT_H
	assert (filename != NULL && cert != NULL && client_name != NULL);
#endif

	/* Retrieve the file size */
	if (stat (filename, &stat_str) != 0) {
		perror (filename);
		return -1;
	}
	file_size = (size_t) stat_str.st_size;

	/* Create a buffer for reading the whole file (+1 for the '\0') */
	file_content = (char *) malloc (file_size + 1);
	if (file_content == NULL) {
		fputs (_("Memory allocation error\n"), stderr);
		return -1;
	}

	/*
	 * Open and read the whole file.  Binary mode because the
	 * certificate may be DER encoded.
	 */
	f = fopen (filename, "rb");
	if (f == NULL) {
		perror (filename);
		free (file_content);
		return -1;
	}

	if (fread (file_content, 1, file_size, f) != file_size) {
		perror (filename);
		fclose (f);
		free (file_content);
		return -1;
	}
	fclose (f);

	file_content[file_size] = '\0';

	/* Initialize the gnutls_x509_crt object */
	ret = gnutls_x509_crt_init (&cert_obj);
	if (ret != 0) {
		fprintf (stderr, _("Cannot initialize the certificate: %s\n"),
				gnutls_strerror (ret));
		free (file_content);
		return -1;
	}

	/*
	 * From now on every error path goes through the `error' label
	 * which releases both cert_obj and file_content.
	 */

	/*
	 * Try to convert the provided certificate (a string) to a
	 * gnutls_x509_crt object.  Try the PEM and then the DER format
	 */
	cert_data.data = (unsigned char *) file_content;
	cert_data.size = (unsigned int) file_size;
	ret = gnutls_x509_crt_import (	cert_obj, &cert_data,
					GNUTLS_X509_FMT_PEM);
	if (	   ret != 0
		&& gnutls_x509_crt_import (	cert_obj, &cert_data,
						GNUTLS_X509_FMT_DER) != 0)
	{
		/* The reported error is the one from the PEM attempt */
		fprintf (stderr,
			_("Cannot import the provided certificate: %s\n"),
			gnutls_strerror (ret));
		goto error;
	}

	/* Check the dates */
	now = time (NULL);

	t = gnutls_x509_crt_get_activation_time (cert_obj);
	if (t > now) {
		fprintf (stderr,
			_("The provided certificate is not yet activated: %s"),
			ctime (&t));
		goto error;
	}

	t = gnutls_x509_crt_get_expiration_time (cert_obj);
	if (t < now) {
		fprintf (stderr, _("The provided certificate has expired: %s"),
				ctime (&t));
		goto error;
	}

	if (t < now + 30 * 24 * 60 * 60) {
		fprintf (stderr,
_("Warning: The provided certificate will expire in less than a month: %s"),
			ctime (&t));
	}

	/*
	 * Retrieve the agent name first from the Alternative name field
	 * in the certificate (X509v3 Certificate Extensions) then, if not
	 * found, from the Common Name.
	 * The size given to GnuTLS always keeps one byte free so that the
	 * returned name can be explicitly '\0' terminated.
	 */
	ret = 0;
	found_dnsname = 0;
	for (i = 0; ret >= 0 && found_dnsname == 0; i++) {
		dnsnamesize = sizeof (dnsname) - 1;
		ret = gnutls_x509_crt_get_subject_alt_name (cert_obj, i,
						dnsname, &dnsnamesize, NULL);
		if (ret == GNUTLS_SAN_DNSNAME) {
			dnsname[dnsnamesize] = '\0';
			found_dnsname = 1;
		}
	}

	if (found_dnsname == 0) {
		dnsnamesize = sizeof (dnsname) - 1;
		ret = gnutls_x509_crt_get_dn_by_oid (cert_obj,
						GNUTLS_OID_X520_COMMON_NAME,
						0, 0,
						dnsname, &dnsnamesize);
		if (ret < 0) {
			fprintf (stderr,
_("Cannot retrieve the remote client name from the provided certificate: %s\n"),
				gnutls_strerror (ret));
			goto error;
		}
		dnsname[dnsnamesize] = '\0';
	}

	agent_name = (char *) malloc (schedwi_strlen (dnsname) + 1);
	if (agent_name == NULL) {
		fputs (_("Memory allocation error\n"), stderr);
		goto error;
	}
	strcpy (agent_name, dnsname);

	gnutls_x509_crt_deinit (cert_obj);

	/* Ownership of both buffers is transferred to the caller */
	*client_name = agent_name;
	*cert        = file_content;
	if (cert_len != NULL) {
		*cert_len = (unsigned long int) file_size;
	}
	return 0;

error:
	gnutls_x509_crt_deinit (cert_obj);
	free (file_content);
	return -1;
}


/*
 * Error callback given to sql_host_replace()
 *
 * Print msg followed by a new line on stderr or, when no message is
 * provided (msg is NULL), a generic database error message.
 * The data and err_code arguments are not used.
 */
static void
sql_host_replace_error (void *data, const char *msg, unsigned int err_code)
{
	(void) data;
	(void) err_code;

	if (msg == NULL) {
		fputs (
		_("Database error while saving the client host details\n"),
			stderr);
		return;
	}

	fputs (msg, stderr);
	fputc ('\n', stderr);
}


/*
 * Save the agent host details in the database
 *
 * The certificate is read from filename (see read_certificat()) and stored
 * with sql_host_replace() under the agent name hostname or, if hostname is
 * NULL, under the host name found in the certificate.  port_number is the
 * agent port (SCHEDWI_DEFAULT_AGTPORT if NULL) and description an optional
 * description passed as is to sql_host_replace().  Unless force is set, a
 * provided hostname must match (case insensitive) the name in the
 * certificate.
 *
 * Return:
 *   0 --> No error
 *  -1 --> Error.  An error message has been printed on stderr
 */
static int
certificate_to_database (	const char *filename,
				const char *hostname,
				const char *port_number,
				const char *description,
				char force)
{
	const char *port;
	char *cert, *client_name;
	unsigned int ret;

#if HAVE_ASSERT_H
	assert (filename != NULL);
#endif

	/* Get the port number */
	if (port_number == NULL) {
		port = SCHEDWI_DEFAULT_AGTPORT;
	}
	else {
		/* Check that the port is valid before saving it */
		if (get_port_number (port_number) == 0) {
			fprintf (stderr,
				_("Cannot get the port number for `%s'\n"),
				port_number);
			return -1;
		}
		port = port_number;
	}

	/* Retrieve the certificate */
	if (read_certificat (filename, &cert, NULL, &client_name) != 0) {
		return -1;
	}

	/*
	 * Compare the provided agent host name with the one
	 * retrieved from the certificate
	 */
	if (	   force == 0 && hostname != NULL
		&& schedwi_strcasecmp (hostname, client_name) != 0)
	{
		fprintf (stderr,
	_("Hostname `%s' is different from the one in certificate (`%s')\n"),
			hostname, client_name);
		free (cert);
		free (client_name);
		return -1;
	}

	/*
	 * Update the database.
	 * NOTE(review): the constant 1 presumably flags the agent as an SSL
	 * host - confirm against the sql_host_replace() prototype.
	 */
	ret = sql_host_replace ((hostname != NULL) ? hostname: client_name,
				port, 1, cert, description,
				sql_host_replace_error, NULL);
	free (cert);
	free (client_name);
	return (ret == 0) ? 0: -1;
}


/*
 * Main function
 *
 * Parse the command line, load the server configuration, initialize the
 * network layer (needed for the SSL initialization) and the database
 * connection, then add or update the agent described by the certificate
 * file given as the only non-option argument.
 *
 * The exit code is 1 in case of error or 0 on success
 */
int
main (int argc, char **argv)
{
	/* NULL so they are never read uninitialized if the assert is off */
	const char *server_key = NULL, *server_crt = NULL;
	const char *prog_name;
	int ret, exit_code;
	char *err_msg;
	const char *agent_host, *port, *descr, *certificate_file;
	char force;

#if HAVE_GETOPT_LONG
	int option_index;
	struct option long_options[] =
	{
		{"help",       0, 0, 'h'},
		{"version",    0, 0, 'V'},
		{"config",     1, 0, 'c'},
		{"agent",      1, 0, 'a'},
		{"port",       1, 0, 'p'},
		{"descr",      1, 0, 'd'},
		{"force",      0, 0, 'f'},
		{0, 0, 0, 0}
	};
#endif

#if HAVE_SETLOCALE
	setlocale (LC_ALL, "");
#endif
#if HAVE_BINDTEXTDOMAIN
	bindtextdomain (PACKAGE, LOCALEDIR);
#endif
#if HAVE_TEXTDOMAIN
	textdomain (PACKAGE);
#endif


	/* Set default values for options */
	configuration_file = SCHEDWI_DEFAULT_CONFFILE_SRV;
	agent_host = port = descr = NULL;
	force = 0;

	prog_name = base_name (argv[0]);

	/* Parse options */
	while (1) {
#if HAVE_GETOPT_LONG
		option_index = 0;
		ret = getopt_long (argc, argv, "hVc:a:p:d:f",
					long_options, &option_index);
#else
		ret = getopt (argc, argv, "hVc:a:p:d:f");
#endif

		if (ret == -1) {
			break;
		}

		switch (ret) {
			case 'h':
				help (prog_name);
				return 0;
			case 'V':
				version (prog_name);
				return 0;
			case 'c':
				configuration_file = optarg;
				break;
			case 'a':
				agent_host = optarg;
				break;
			case 'p':
				port = optarg;
				break;
			case 'd':
				descr = optarg;
				break;
			case 'f':
				force = 1;
				break;
			default:
				help (prog_name);
				return 1;
		}
	}

	/* Certificate file name missing */
	if (optind >= argc) {
		fputs (_("Certificate file name required\n"), stderr);
		help (prog_name);
		return 1;
	}

	/* Too many parameters */
	if (optind + 1 != argc) {
		fputs (_("Too many parameters\n"), stderr);
		help (prog_name);
		return 1;
	}

	certificate_file = argv[optind];

	/*
	 * Read the configuration file
	 */
	ret = conf_init_srv (configuration_file);
	switch (ret) {
		case -1:
			fputs (_("Memory allocation error\n"), stderr);
			break;

		case -2:
			perror (configuration_file);
			break;

		default:
			break;
	}
	if (ret != 0) {
		return 1;
	}

	/*
	 * Initialize the network (only needed for the SSL init)
	 */
	ret = conf_get_param_string ("SSLServerCertificateFile", &server_crt);
	ret += conf_get_param_string (	"SSLServerCertificateKeyFile",
					&server_key);
#if HAVE_ASSERT_H
	assert (ret == 0);
#endif

	if (net_init (server_crt, server_key) != 0) {
		conf_destroy_srv ();
		return 1;
	}

	/*
	 * Database connection
	 */
	err_msg = NULL;
	if (begin_mysql (&err_msg) == NULL) {
		if (err_msg != NULL) {
			fprintf (stderr,
				_("Failed to connect to database: %s\n"),
				err_msg);
			free (err_msg);
		}
		else {
			fputs (_("Failed to connect to database\n"), stderr);
		}
		conf_destroy_srv ();
		net_destroy ();
		return 1;
	}

	/* Add/Update the new agent */
	exit_code = (certificate_to_database (	certificate_file, agent_host,
						port, descr, force) == 0)
			? 0 : 1;

	/*
	 * Release the database connection, the configuration and the
	 * network layer on both the success and the error paths
	 */
	end_mysql ();
	conf_destroy_srv ();
	net_destroy ();
	return exit_code;
}

/*------------------------======= End Of File =======------------------------*/
