/*
 * RageIRCd: an advanced Internet Relay Chat daemon (ircd).
 * (C) 2000-2005 the RageIRCd Development Team, all rights reserved.
 *
 * This software is free, licensed under the General Public License.
 * Please refer to doc/LICENSE and doc/README for further details.
 *
 * $Id: ssl.c,v 1.30.2.1 2004/12/07 03:05:41 pneumatus Exp $
 */

#include "setup.h"

#ifdef USE_OPENSSL

#include "config.h"
#include "struct.h"
#include "ssl.h"
#include "h.h"
#include "memory.h"
#include "common.h"
#include <errno.h>
#include <string.h>
#include <unistd.h>
#include <openssl/err.h>

/*
 * Server-side SSL context. It is built by init_ircd_ctx() and used by
 * ssl_do_handshake() to create the SSL object for each accepted socket.
 */
static SSL_CTX *ircd_ctx;

/* Forward declarations for the file-local helpers below. */
static int ssl_pem_passwd_cb(char *buf, int bufsize, int rwflag, void *password);
static int init_ircd_ctx();
static char *ssl_get_init_error_message(int error);
static void ssl_handle_fatal_error(int ssl_errno, int what, int fd, char *host);

/*
 * ssl_pem_passwd_cb()
 *   OpenSSL PEM passphrase callback used when loading the private key.
 *
 * buf      - destination buffer supplied by OpenSSL
 * bufsize  - size of buf in bytes (the copy is bounded by this, never by
 *            the ircd's own BUFSIZE)
 * rwflag   - unused
 * password - unused (no user data is registered with the context)
 *
 * Returns the length of the NUL-terminated passphrase written to buf
 * (always < bufsize), or 0 when no passphrase is available.
 *
 * On Cygwin the passphrase comes from Internal.ssl_privkey_passwd. Elsewhere
 * it is read once with getpass() and cached, so later calls do not prompt
 * again.
 */
static int ssl_pem_passwd_cb(char *buf, int bufsize, int rwflag, void *password)
{
#ifndef OS_CYGWIN
	static int been_here = 0;
	static char saved_pass[BUFSIZE];
#endif
	const char *pass;

	(void) rwflag;
	(void) password;

	/* strncpyzt() writes buf[bufsize - 1]; a non-positive size cannot hold
	 * even the terminator. */
	if (buf == NULL || bufsize <= 0) {
		return 0;
	}

#ifdef OS_CYGWIN
	pass = Internal.ssl_privkey_passwd;
#else
	if (been_here) {
		pass = saved_pass;
	}
	else {
		pass = getpass("SSL private key passphrase: ");
		if (BadPtr(pass)) {
			return 0;
		}
		/* Cache for later calls; the cache is bounded by its own size. */
		strncpyzt(saved_pass, pass, BUFSIZE);
		been_here = 1;
	}
#endif

	/* Always bound by the caller's buffer size and always NUL-terminate, so
	 * the returned length can never exceed what OpenSSL handed us. */
	strncpyzt(buf, pass, bufsize);
	return (int) strlen(buf);
}

static int init_ircd_ctx()
{
	if ((ircd_ctx = SSL_CTX_new(SSLv23_server_method())) == NULL) {
		return ERR_SSL_CTX_NEW;
	}

	SSL_CTX_set_default_passwd_cb(ircd_ctx, ssl_pem_passwd_cb);
	SSL_CTX_set_session_cache_mode(ircd_ctx, SSL_SESS_CACHE_OFF);

	if (SSL_CTX_use_certificate_file(ircd_ctx, ServerInfo->ssl_certificate, SSL_FILETYPE_PEM) <= 0) {
		SSL_CTX_free(ircd_ctx);
		return ERR_SSL_CTX_USE_CERT;
	}
	if (SSL_CTX_use_PrivateKey_file(ircd_ctx, ServerInfo->ssl_private_key, SSL_FILETYPE_PEM) <= 0) {
		SSL_CTX_free(ircd_ctx);
		return ERR_SSL_CTX_USE_KEY;
	}
	if (!SSL_CTX_check_private_key(ircd_ctx)) {
		SSL_CTX_free(ircd_ctx);
		return ERR_SSL_CTX_CHECK_KEY;
	}

	return ERR_SSL_SUCCESS;
}

/*
 * ssl_get_init_error_message()
 *   Maps an init_ircd_ctx() result code to a static, human-readable
 *   description. Codes that are not recognised get a generic message.
 */
static char *ssl_get_init_error_message(int error)
{
	char *text;

	switch (error) {
		case ERR_SSL_SUCCESS:
			text = "success (no error)";
			break;
		case ERR_SSL_CTX_NEW:
			text = "failed to create new SSL context";
			break;
		case ERR_SSL_CTX_USE_CERT:
			text = "failed to open certificate file";
			break;
		case ERR_SSL_CTX_USE_KEY:
			text = "failed to open private key file";
			break;
		case ERR_SSL_CTX_CHECK_KEY:
			text = "authentication failure (certificate does not match private key)";
			break;
		default:
			text = "unknown error occured in ssl_get_init_error_message()";
			break;
	}

	return text;
}

/*
 * init_ssl()
 *   Initialises the OpenSSL library and the server SSL context.
 *
 * Returns NULL on success. On failure it returns a static string that
 * describes what went wrong.
 */
char *init_ssl()
{
	int result;

	SSL_load_error_strings();
	SSL_library_init();

	result = init_ircd_ctx();
	if (result == ERR_SSL_SUCCESS) {
		if (Internal.verbose) {
			report(0, "Loaded SSL certificate and private key file.");
		}
		return NULL;
	}

	return ssl_get_init_error_message(result);
}

/*
 * ssl_get_errstr()
 *   Returns a static description for an SSL_get_error() result code.
 *   Codes not listed here map to "Unknown SSL Error".
 */
static char *ssl_get_errstr(int ssl_errno)
{
	switch (ssl_errno) {
		case SSL_ERROR_NONE:
			return "No error";
		case SSL_ERROR_SSL:
			return "Internal SSL or protocol error";
		case SSL_ERROR_WANT_CONNECT:
			return "SSL functions requested connect()";
		case SSL_ERROR_WANT_READ:
			return "SSL functions requested read()";
		case SSL_ERROR_WANT_WRITE:
			return "SSL functions requested write()";
		case SSL_ERROR_WANT_X509_LOOKUP:
			return "SSL requested an X509 lookup which didn't arrive";
		case SSL_ERROR_SYSCALL:
			return "Underlying syscall error";
		case SSL_ERROR_ZERO_RETURN:
			return "Underlying socket operation returned zero";
		default:
			return "Unknown SSL Error";
	}
}

/*
 * ssl_handle_fatal_error()
 *   Reports a fatal SSL failure on a socket through report_error().
 *
 * ssl_errno - result of SSL_get_error()
 * what      - SAFE_SSL_READ, SAFE_SSL_WRITE or SAFE_SSL_ACCEPT
 * fd, host  - passed through to report_error()
 *
 * The text given to report_error() is itself a format string. Its first
 * "%s" slot receives host. When errno is non-zero a second "%s" slot
 * follows, presumably filled with the errno text by report_error() (TODO
 * confirm). When errno is zero the slot reads "no error" and errno is set
 * to EIO, so the caller still sees a failure.
 *
 * The OpenSSL error queue is drained here. SSL_get_error() looks at that
 * per-thread queue before anything else, so stale entries left behind would
 * make later, unrelated connections appear to fail with SSL_ERROR_SSL.
 */
static void ssl_handle_fatal_error(int ssl_errno, int what, int fd, char *host)
{
	int saved_errno = errno;
	char *ssl_err = ssl_get_errstr(ssl_errno);
	char *ssl_func, err_buf[512], detail[256], reason[128];
	const char *r;
	char *p;
	unsigned long queued;

	switch (what) {
		case SAFE_SSL_READ:
			ssl_func = "SSL_read()";
			break;
		case SAFE_SSL_WRITE:
			ssl_func = "SSL_write()";
			break;
		case SAFE_SSL_ACCEPT:
			ssl_func = "SSL_accept()";
			break;
		default:
			ssl_func = "[Unknown SSL Function]";
			break;
	}

	/* Keep the oldest queued error (usually the root cause) and discard the rest. */
	reason[0] = '\0';
	if ((queued = ERR_get_error()) != 0) {
		if ((r = ERR_reason_error_string(queued)) != NULL) {
			strncpyzt(reason, r, sizeof(reason));
			/* err_buf is later used as a format string, so no stray '%' may
			 * pass through from library text. */
			for (p = reason; *p != '\0'; p++) {
				if (*p == '%') {
					*p = '?';
				}
			}
		}
		ERR_clear_error();
	}
	/* The ERR_* calls must not disturb the errno the caller reported. */
	errno = saved_errno;

	if (reason[0] != '\0') {
		ircsprintf(detail, "%s: %s", ssl_err, reason);
	}
	else {
		strncpyzt(detail, ssl_err, sizeof(detail));
	}

	if (errno) {
		ircsprintf(err_buf, "%s failed (%s) for %%s: %%s", ssl_func, detail);
	}
	else {
		ircsprintf(err_buf, "%s failed (%s) for %%s: no error", ssl_func, detail);
		errno = EIO;
	}
	report_error(fd, err_buf, host);
}

/*
 * safe_SSL_accept()
 *   Drives one step of the server-side SSL handshake on ssl.
 *
 * Returns 1 when the handshake has completed, or when it is still in
 * progress (the handshake needs more I/O, or the socket would block or was
 * interrupted). Returns -1 on a fatal error, after reporting it through
 * ssl_handle_fatal_error().
 */
int safe_SSL_accept(SSL *ssl, int fd, char *host)
{
	int ret, ssl_errno;

	/* SSL_get_error() is only reliable with an empty error queue. errno is
	 * cleared so a stale EAGAIN cannot disguise an EOF (SSL_ERROR_SYSCALL
	 * with errno untouched) as "handshake in progress". */
	ERR_clear_error();
	errno = 0;

	if ((ret = SSL_accept(ssl)) > 0) {
		return 1;
	}

	ssl_errno = SSL_get_error(ssl, ret);
	switch (ssl_errno) {
		case SSL_ERROR_SYSCALL:
			if (errno != EWOULDBLOCK && errno != EAGAIN && errno != EINTR) {
				break;	/* genuine socket error or EOF: fatal */
			}
			/* fallthrough: transient socket condition */
		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			Debug((DEBUG_DEBUG, "safe_SSL_accept(%d,[%s]) handshake in "
				"progress (%s)", fd, host, ssl_get_errstr(ssl_errno)));
			return 1;
		default:
			break;
	}

	Debug((DEBUG_DEBUG, "safe_SSL_accept(%d,[%s]) fatal error", fd, host));
	ssl_handle_fatal_error(ssl_errno, SAFE_SSL_ACCEPT, fd, host);
	return -1;
}

/*
 * safe_SSL_read()
 *   read()-like wrapper around SSL_read() for cptr's SSL connection.
 *
 * Returns the number of bytes read (> 0), or -1. On -1:
 *   - errno == EWOULDBLOCK: no data available yet (includes SSL_read()
 *     needing to write during renegotiation); retry later;
 *   - otherwise the failure is fatal and has been reported through
 *     ssl_handle_fatal_error().
 */
int safe_SSL_read(aClient *cptr, void *buf, size_t size)
{
	int len, ssl_errno;

	ASSERT(cptr != NULL);
	ASSERT(cptr->localClient != NULL);

	/* An empty error queue is required for SSL_get_error() to be reliable.
	 * Clearing errno stops a stale EAGAIN from turning EOF
	 * (SSL_ERROR_SYSCALL with errno untouched) into an endless EWOULDBLOCK. */
	ERR_clear_error();
	errno = 0;

	if ((len = SSL_read(cptr->localClient->ssl, buf, size)) > 0) {
		return len;
	}

	ssl_errno = SSL_get_error(cptr->localClient->ssl, len);
	switch (ssl_errno) {
		case SSL_ERROR_SYSCALL:
			if (errno != EWOULDBLOCK && errno != EAGAIN && errno != EINTR) {
				break;	/* genuine socket error or EOF: fatal */
			}
			/* fallthrough: transient socket condition */
		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			/* WANT_WRITE can occur on a read during renegotiation; it is a
			 * retry condition, not a failure. */
			errno = EWOULDBLOCK;
			Debug((DEBUG_DEBUG, "safe_SSL_read(%d) returned EWOULDBLOCK (%s)",
				cptr->localClient->fd, ssl_get_errstr(ssl_errno)));
			return -1;
		case SSL_ERROR_SSL:
			/* Legacy workaround kept from the original code: treat as would-block. */
			if (errno == EAGAIN) {
				return -1;
			}
			break;
		default:
			break;
	}

	Debug((DEBUG_DEBUG, "safe_SSL_read(%d) fatal error", cptr->localClient->fd));
	ssl_handle_fatal_error(ssl_errno, SAFE_SSL_READ, cptr->localClient->fd,
		get_client_name(cptr, 1));
	return -1;
}

/*
 * safe_SSL_write()
 *   write()-like wrapper around SSL_write() for cptr's SSL connection.
 *
 * Returns the number of bytes written (> 0), 0 for a zero-length request,
 * or -1. On -1:
 *   - errno == EWOULDBLOCK: the write must be retried later (includes
 *     SSL_write() needing to read during renegotiation);
 *   - otherwise the failure is fatal and has been reported through
 *     ssl_handle_fatal_error(), exactly as in safe_SSL_read().
 */
int safe_SSL_write(aClient *cptr, const void *buf, size_t size)
{
	int len, ssl_errno;

	ASSERT(cptr != NULL);
	ASSERT(cptr->localClient != NULL);

	/* Nothing to send; avoid asking OpenSSL for an ambiguous zero-byte write. */
	if (size == 0) {
		return 0;
	}

	/* An empty error queue is required for SSL_get_error() to be reliable.
	 * Clearing errno keeps a stale value from being misattributed to this call. */
	ERR_clear_error();
	errno = 0;

	if ((len = SSL_write(cptr->localClient->ssl, buf, size)) > 0) {
		return len;
	}

	ssl_errno = SSL_get_error(cptr->localClient->ssl, len);
	switch (ssl_errno) {
		case SSL_ERROR_SYSCALL:
			if (errno != EWOULDBLOCK && errno != EAGAIN && errno != EINTR) {
				break;	/* genuine socket error or EOF: fatal */
			}
			/* fallthrough: transient socket condition */
		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			/* WANT_READ can occur on a write during renegotiation; it is a
			 * retry condition, not a failure. */
			errno = EWOULDBLOCK;
			return -1;
		case SSL_ERROR_SSL:
			/* Legacy workaround kept from the original code: treat as would-block. */
			if (errno == EAGAIN) {
				return -1;
			}
			break;
		default:
			break;
	}

	Debug((DEBUG_DEBUG, "safe_SSL_write(%d) fatal error", cptr->localClient->fd));
	ssl_handle_fatal_error(ssl_errno, SAFE_SSL_WRITE, cptr->localClient->fd,
		get_client_name(cptr, 1));
	return -1;
}

/*
 * SSL_smart_shutdown()
 *   Best-effort SSL shutdown. SSL_shutdown() is called until it returns a
 *   non-zero value, up to a fixed number of attempts.
 *
 * Returns the last SSL_shutdown() result. Per the OpenSSL docs this is
 * 1 = complete, 0 = close_notify sent but the peer's not yet received,
 * < 0 = error.
 */
int SSL_smart_shutdown(SSL *ssl)
{
	enum { SHUTDOWN_ATTEMPTS = 4 };
	int attempt, retval = 0;

	for (attempt = 0; attempt < SHUTDOWN_ATTEMPTS && retval == 0; attempt++) {
		retval = SSL_shutdown(ssl);
	}

	/* Shutdown failures are deliberately ignored here. Drop whatever they
	 * queued (e.g. "shutdown while in init") so later SSL_get_error() calls
	 * on other connections are not misled by it. */
	ERR_clear_error();
	return retval;
}

/*
 * ssl_do_handshake()
 *   Creates an SSL object for socket fd from the server context and starts
 *   the server-side handshake.
 *
 * Returns the new SSL object when the handshake has completed or is still
 * in progress (see safe_SSL_accept()). Returns NULL if the object cannot be
 * created or attached to fd, or if the handshake failed fatally; in those
 * cases nothing is left allocated.
 */
SSL *ssl_do_handshake(int fd, char *host)
{
	SSL *ssl;

	/* init_ircd_ctx() leaves the context NULL when initialisation failed. */
	if (ircd_ctx == NULL) {
		ircdlog(LOG_ERROR, "No SSL context; cannot create SSL object for socket %d", fd);
		return NULL;
	}

	if ((ssl = SSL_new(ircd_ctx)) == NULL) {
		ircdlog(LOG_ERROR, "Failed to create new SSL object for socket %d", fd);
		return NULL;
	}

	/* SSL_set_fd() allocates a socket BIO and can fail. Without this check
	 * the handshake would run on an object with no transport. */
	if (!SSL_set_fd(ssl, fd)) {
		ircdlog(LOG_ERROR, "Failed to attach socket %d to SSL object", fd);
		SSL_free(ssl);
		return NULL;
	}

	if (safe_SSL_accept(ssl, fd, host) == -1) {
		/* Pretend the peer's close_notify arrived so shutdown does not wait
		 * on a connection that has already failed. */
		SSL_set_shutdown(ssl, SSL_RECEIVED_SHUTDOWN);
		SSL_smart_shutdown(ssl);
		SSL_free(ssl);
		return NULL;
	}

	return ssl;
}

/*
 * ssl_get_cipher_info()
 *   Formats "<protocol>-<cipher name>-<N>bits" for the connection.
 *
 * Returns a pointer to a static buffer that the next call overwrites, so it
 * is not reentrant. When there is no current cipher, the name comes from
 * SSL_get_cipher() and the bit count is 0.
 */
char *ssl_get_cipher_info(SSL *ssl)
{
	static char buf[400];
	const SSL_CIPHER *c = SSL_get_current_cipher(ssl);
	const char *ver;
	int bits = 0;	/* SSL_CIPHER_get_bits() leaves it untouched for a NULL cipher */

	/* SSL_version() avoids dereferencing ssl->session, which may be NULL
	 * and is opaque in newer OpenSSL releases. */
	switch (SSL_version(ssl)) {
#ifdef SSL2_VERSION
		case SSL2_VERSION:
			ver = "SSLv2";
			break;
#endif
		case SSL3_VERSION:
			ver = "SSLv3";
			break;
		case TLS1_VERSION:
			ver = "TLSv1";
			break;
#ifdef TLS1_1_VERSION
		case TLS1_1_VERSION:
			ver = "TLSv1.1";
			break;
#endif
#ifdef TLS1_2_VERSION
		case TLS1_2_VERSION:
			ver = "TLSv1.2";
			break;
#endif
#ifdef TLS1_3_VERSION
		case TLS1_3_VERSION:
			ver = "TLSv1.3";
			break;
#endif
		default:
			ver = "UNKNOWN";
			break;
	}

	if (c != NULL) {
		SSL_CIPHER_get_bits(c, &bits);
	}
	ircsprintf(buf, "%s-%s-%dbits", ver, SSL_get_cipher(ssl), bits);
	return buf;
}

#endif
