/*
** Copyright 2000 Double Precision, Inc.
** See COPYING for distribution information.
*/
#include	"config.h"
#include	"argparse.h"
#include	"spipe.h"
#include	"rfc1035/rfc1035.h"
#include	<stdio.h>
#include	<string.h>
#include	<stdlib.h>
#include	<netdb.h>
#if	HAVE_UNISTD_H
#include	<unistd.h>
#endif
#if	HAVE_FCNTL_H
#include	<fcntl.h>
#endif
#include	<errno.h>
#if	HAVE_SYS_TYPES_H
#include	<sys/types.h>
#endif
#include	<sys/socket.h>
#include	<arpa/inet.h>
#include	<openssl/ssl.h>
#include	<openssl/err.h>
#include	<sys/time.h>

/* RCS revision tag embedded in the object file. */
static const char rcsid[]="$Id: starttls.c,v 1.7 2000/07/24 23:35:23 mrsam Exp $";

/* Temporary RSA key callback; only installed (in create_tls) for servers. */
#ifndef NO_RSA
static RSA *rsa_callback(SSL *, int, int);
#endif

/*
** TLS settings.  main() fills these in from the TLS_* environment
** variables; create_tls() applies them to the SSL_CTX.
*/
const char *ssl_cipher_list=0;	/* TLS_CIPHER_LIST; 0: leave the default */
int session_timeout=0;		/* TLS_TIMEOUT, via atoi() */
const char *dhcertfile=0;	/* TLS_DHCERTFILE: PEM file with DH parameters */
const char *certfile=0;		/* TLS_CERTFILE: PEM file, certificate + key */
const char *protocol=0;		/* TLS_PROTOCOL: "SSL2", "SSL3", else TLSv1 */

/* TLS_VERIFYPEER selects one of the following (first letter N, P or R): */
int peer_verify_level=SSL_VERIFY_PEER;
		/* SSL_VERIFY_NONE */
		/* SSL_VERIFY_PEER */
		/* SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT */
const char *peer_cert_dir=0;	/* TLS_PEERCERTDIR */
const char *ourcacert=0;	/* TLS_OURCACERT */
int okselfsignedcert=0;		/* TLS_ALLOWSELFSIGNEDCERT, via atoi() */

/* Command-line options: */
const char *clienthost=0;	/* -host */
const char *clientport=0;	/* -port */

const char *server=0;		/* -server: act as the TLS server side */
const char *localfd=0;		/* -localfd: descriptor used for plaintext I/O */
const char *remotefd=0;		/* -remotefd: already-connected socket */
const char *tcpd=0;		/* -tcpd: the socket is descriptor 0 */
const char *peer_verify_domain=0; /* -verify: domain the peer's CN must match */

/* Set by ssl_verify_callback(), checked by connect_tls() after handshake. */
int peer_domain_verified;

const char *printx509=0;	/* -printx509: descriptor number, as text */
FILE *printx509_fp;		/* Stream opened on that descriptor */

/* -------------------------------------------------------------------- */

/* Certificate verification hook, installed by create_tls(). */
static int ssl_verify_callback(int, X509_STORE_CTX *);

/*
** Report the oldest error on the OpenSSL error queue to stderr.
**
** pfix - context (operation or file name) shown before the error text.
*/
static void sslerror(const char *pfix)
{
char		text[256];
unsigned long	code;

	code=ERR_get_error();
	ERR_error_string(code, text);

	fprintf(stderr, "starttls: %s: %s\n", pfix, text);
}

/*
** Create an SSL_CTX configured from the file-scope TLS settings
** (protocol, ssl_cipher_list, session_timeout, dhcertfile, certfile,
** peer_cert_dir, ourcacert, peer_verify_level).
**
** isserver - non-zero to also install the server-only temporary RSA
**            callback and DH parameters.
**
** Returns the new context, or 0 after printing a diagnostic to stderr.
** The caller releases the context with SSL_CTX_free() (disconnect_tls()
** does so).
*/
SSL_CTX *create_tls(int isserver)
{
SSL_CTX *ctx;

	SSL_load_error_strings();
	SSLeay_add_ssl_algorithms();

	/*
	** NOTE(review): "SSL3" selects SSLv23_method(), not SSLv3_method().
	** Left as found -- confirm that this is intended.
	*/
	ctx=SSL_CTX_new(protocol && strcmp(protocol, "SSL2") == 0
							? SSLv2_method():
		protocol && strcmp(protocol, "SSL3") == 0 ? SSLv23_method():
		TLSv1_method());

	if (!ctx)
	{
		/* OpenSSL reports through its own error queue, not errno. */
		sslerror("SSL_CTX_new");
		return (0);
	}
	SSL_CTX_set_options(ctx, SSL_OP_ALL);

	/*
	** A cipher list that selects nothing must not be ignored: the
	** context would silently keep the default cipher list instead of
	** the restriction that was asked for.
	*/
	if (ssl_cipher_list &&
		!SSL_CTX_set_cipher_list(ctx, ssl_cipher_list))
	{
		sslerror(ssl_cipher_list);
		SSL_CTX_free(ctx);
		return (0);
	}
	SSL_CTX_set_timeout(ctx, session_timeout);

	if (isserver)
	{
#ifndef NO_RSA
		SSL_CTX_set_tmp_rsa_callback(ctx, rsa_callback);
#endif

#ifndef	NO_DH
		if (dhcertfile)
		{
		BIO	*bio;
		DH	*dh=0;

			if ((bio=BIO_new_file(dhcertfile, "r")) != 0)
			{
				dh=PEM_read_bio_DHparams(bio, NULL, NULL,
					NULL);
				if (!dh)
					sslerror(dhcertfile);
				BIO_free(bio);
			}
			else
				sslerror(dhcertfile);

			if (!dh)
			{
				fprintf(stderr, "starttls: DH init failed!\n");
				SSL_CTX_free(ctx);
				return (0);
			}
			SSL_CTX_set_tmp_dh(ctx, dh);
			DH_free(dh);
		}
#endif
	}
	SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_BOTH);

	/* certfile supplies both the certificate and its private key. */
	if (certfile)
	{
		if(!SSL_CTX_use_certificate_file(ctx, certfile,
			SSL_FILETYPE_PEM))
		{
			sslerror(certfile);
			SSL_CTX_free(ctx);
			return (0);
		}
#ifndef	NO_RSA
		if(!SSL_CTX_use_RSAPrivateKey_file(ctx, certfile,
			SSL_FILETYPE_PEM))
#else
		if(!SSL_CTX_use_PrivateKey_file(ctx, certfile,
			SSL_FILETYPE_PEM))
#endif
		{
			sslerror(certfile);
			SSL_CTX_free(ctx);
			return (0);
		}
	}

	/* Trusted CAs are only loaded when both settings are present. */
	if (peer_cert_dir && ourcacert)
	{
		if ((!SSL_CTX_set_default_verify_paths(ctx))
			|| (!SSL_CTX_load_verify_locations(ctx, ourcacert,
				peer_cert_dir)))
		{
			sslerror(peer_cert_dir);
			SSL_CTX_free(ctx);
			return (0);
		}
		SSL_CTX_set_client_CA_list(ctx,
				SSL_load_client_CA_file(ourcacert));
	}
	SSL_CTX_set_verify(ctx, peer_verify_level, ssl_verify_callback);
	return (ctx);
}

/*
** Certificate verification hook installed by create_tls().
**
** goodcert - OpenSSL's own verdict for the certificate being examined.
** x509     - verification context; current_cert is that certificate.
**
** Side effects: writes the subject line to printx509_fp when printx509
** is set, and sets peer_domain_verified when the peer certificate's CN
** matches peer_verify_domain.
**
** Returns 1 to accept the certificate, 0 to reject it.  A certificate
** that OpenSSL rejected is still accepted when okselfsignedcert is set.
*/
static int ssl_verify_callback(int goodcert, X509_STORE_CTX *x509)
{
char subject[256];

	subject[0]=0;	/* Stay printable if the name cannot be formatted */
	X509_NAME_oneline(X509_get_subject_name(x509->current_cert),
		subject, sizeof(subject));

	if (printx509)
	{
		fprintf(printx509_fp, "Subject: %s\n", subject);
		fflush(printx509_fp);
	}

	/*
	** The callback runs once for every certificate in the chain.  Only
	** the peer's own certificate (depth 0) may vouch for the domain;
	** otherwise a CA certificate whose CN happens to match would pass
	** the check on the peer's behalf.
	*/
	if (goodcert && peer_verify_domain &&
		X509_STORE_CTX_get_error_depth(x509) == 0)
	{
	char	*field;

		/*
		** strtok() chops up subject[] in place.  That is harmless:
		** the only later use of subject[] is on the !goodcert path.
		*/
		for (field=strtok(subject, "/"); field != 0;
			field=strtok(0, "/"))
		{
			if (strncasecmp(field, "CN=", 3))
				continue;
			field += 3;

			if (*field == '*')
			{
			size_t	cnlen, domlen;

				/* "*suffix": the domain must end in suffix */
				cnlen=strlen(++field);
				domlen=strlen(peer_verify_domain);

				if (cnlen <= domlen &&
					strcasecmp(peer_verify_domain
						+domlen-cnlen, field) == 0)
					peer_domain_verified=1;
			}
			else if (strcasecmp(peer_verify_domain, field) == 0)
				peer_domain_verified=1;
		}
	}

	if (!okselfsignedcert && !goodcert)
	{
		fprintf(stderr, "starttls: Bad certificate from %s\n",
			subject);
		return (0);
	}
	return 1;
}

#ifndef NO_RSA
/*
** Temporary RSA key callback, installed by create_tls() for servers.
** Generates a fresh key of the requested size on every call.
**
** The definition is guarded exactly like its prototype and its only
** use, so that a NO_RSA build does not reference the RSA type.
**
** s, export - unused.
** keylength - size of the key to generate, in bits.
**
** Returns the new key, or NULL if key generation failed.
*/
static RSA *rsa_callback(SSL *s, int export, int keylength)
{
	(void)s;
	(void)export;

	fprintf(stderr, "Called rsa_callback\n");
	return (RSA_generate_key(keylength,RSA_F4,NULL,NULL));
}
#endif

SSL *connect_tls(SSL_CTX *ctx, int isserver, int fd)
{
SSL *ssl;

	if (!(ssl=SSL_new(ctx)))
	{
		sslerror("SSL_new");
		return (0);
	}

	SSL_set_fd(ssl, fd);
	peer_domain_verified=0;

	if (!peer_verify_domain)
		peer_domain_verified=1;

	if (printx509)
	{
		printx509_fp=fdopen(atoi(printx509), "w");
		if (!printx509_fp)
		{
			perror("fdopen");
			printx509=0;
		}
	}

	if (isserver)
	{
		SSL_set_accept_state(ssl);
		if (SSL_accept(ssl) <= 0)
		{
			sslerror("accept");
			SSL_set_shutdown(ssl,
				SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN);
			SSL_free(ssl);
			ERR_remove_state(0);
			if (printx509)
				fclose(printx509_fp);
			printx509=0;
			return (0);
		}
	}
	else
	{
		SSL_set_connect_state(ssl);
		if (SSL_connect(ssl) <= 0)
		{
			sslerror("accept");
			SSL_set_shutdown(ssl,
				SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN);
			SSL_free(ssl);
			ERR_remove_state(0);
			if (printx509)
				fclose(printx509_fp);
			printx509=0;
			return (0);
		}
	}
	if (printx509)
		fclose(printx509_fp);
	printx509=0;

	if (!peer_domain_verified)
	{
		SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN);
		SSL_free(ssl);
		ERR_remove_state(0);
		fprintf(stderr, "starttls: peer domain not verified.\n");
		return (0);
	}
	return (ssl);
}

/*
** Tear down a TLS session and release its context.
**
** ctx - context to free.
** ssl - session to free; may be 0, in which case only ctx is released.
**
** The session is flagged as already shut down in both directions before
** it is freed.
*/
void disconnect_tls(SSL_CTX *ctx, SSL *ssl)
{
	if (ssl != 0)
	{
	const int shutdown_both=SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN;

		SSL_set_shutdown(ssl, shutdown_both);
		SSL_free(ssl);
		ERR_remove_state(0);
	}

	SSL_CTX_free(ctx);
}

#if 0 /* Debugging aid, compiled out. */
/*
** Dump cnt bytes of buf to stderr on one line, prefixed by
** "couriertls: <what>: ".  Newline and carriage return are shown as \n
** and \r; other bytes below ' ' are shown as \xNN.
*/
static void mylog(const char *what, const char *buf, int cnt)
{
fprintf(stderr, "couriertls: %s: ", what);

	for (; cnt; ++buf, --cnt)
	{
		if (*buf == '\n')
		{
			fprintf(stderr, "\\n");
			continue;
		}
		if (*buf == '\r')
		{
			fprintf(stderr, "\\r");
			continue;
		}
		/* Remaining control characters: print as hex */
		if ((int)(unsigned char)*buf < ' ')
		{
			fprintf(stderr, "\\x%02X", (int)(unsigned char)*buf);
			continue;
		}
		fprintf(stderr, "%c", *buf);
	}
	fprintf(stderr, "\n");
	fflush(stderr);
}
#endif

/*
** Relay data between an established TLS session and a pair of plain
** descriptors until one side stops.
**
** ssl      - connected session.
** sslfd    - descriptor underneath ssl; switched to non-blocking mode.
** stdinfd  - plaintext is read from here and written to ssl.
** stdoutfd - plaintext read from ssl is written here.
**
** Each direction has one buffer.  A buffer is refilled only once it has
** been drained completely, so for each direction the loop is either
** waiting to read or waiting to write, never both.
**
** Returns (with nothing closed) when a read or write yields <= 0, when
** select() yields <= 0, or when SSL reports an error other than
** WANT_READ / WANT_WRITE / WANT_X509_LOOKUP.
*/
void transfer_tls(SSL *ssl, int sslfd, int stdinfd, int stdoutfd)
{
char	from_ssl_buf[BUFSIZ], to_ssl_buf[BUFSIZ];
char	*from_ssl_ptr=0, *to_ssl_ptr=0;	/* Next unsent byte in each buffer */
int	from_ssl_cnt, to_ssl_cnt;	/* Unsent bytes left in each buffer */
fd_set	fdr, fdw;
int	maxfd=sslfd;
int	suppress_read;	/* Set after SSL asked for WANT_WRITE */
int	suppress_write;	/* Set after SSL asked for WANT_READ */

	/*
	** NOTE(review): F_SETFL is given O_NONBLOCK alone, without first
	** fetching the current flags with F_GETFL.
	*/
	if (fcntl(sslfd, F_SETFL, O_NONBLOCK))
	{
		perror("fcntl");
		return;
	}

	if (maxfd < stdinfd)	maxfd=stdinfd;
	if (maxfd < stdoutfd)	maxfd=stdoutfd;

	from_ssl_cnt=0;
	to_ssl_cnt=0;

	suppress_read=0;
	suppress_write=0;

	for (;;)
	{
		FD_ZERO(&fdr);
		FD_ZERO(&fdw);

		/* SSL -> stdout direction: flush first, then read more */
		if (from_ssl_cnt)
			FD_SET(stdoutfd, &fdw);
		else if (!suppress_read)
		{
		int	n=SSL_pending(ssl);

			/*
			** Data that SSL has already buffered will not make
			** sslfd readable, so fetch it without select().
			*/
			if (n > 0)
			{
				if (n >= sizeof(from_ssl_buf))
					n=sizeof(from_ssl_buf);

				n=SSL_read(ssl, from_ssl_buf, n);
				switch (SSL_get_error(ssl, n))	{
				case SSL_ERROR_NONE:
					if (n <= 0)
						return;
					break;
				case SSL_ERROR_WANT_WRITE:
					suppress_read=1;
					suppress_write=0;
					continue;
				case SSL_ERROR_WANT_READ:
					suppress_read=0;
					suppress_write=1;
					continue;
				case SSL_ERROR_WANT_X509_LOOKUP:
					continue;
				default:
					return;
				}
				from_ssl_cnt=n;
				from_ssl_ptr=from_ssl_buf;
				suppress_read=0;
				suppress_write=0;
				continue;
			}
			else
				FD_SET(sslfd, &fdr);
		}

		/* stdin -> SSL direction: flush first, then read more */
		if (to_ssl_cnt)
		{
			if (!suppress_write)
				FD_SET(sslfd, &fdw);
		}
		else
			FD_SET(stdinfd, &fdr);

		/*
		** Blocks with no timeout.  NOTE(review): any result <= 0
		** ends the transfer, which would include an interrupted
		** call.
		*/
		if (select(maxfd+1, &fdr, &fdw, 0, 0) <= 0)
		{
			perror("select");
			return;
		}

		if (from_ssl_cnt && FD_ISSET(stdoutfd, &fdw))
		{
		int n=write(stdoutfd, from_ssl_ptr, from_ssl_cnt);

			/* A short write leaves the rest for the next pass */
			if (n <= 0)	return;
			from_ssl_ptr += n;
			from_ssl_cnt -= n;
		}
		else if (!from_ssl_cnt && FD_ISSET(sslfd, &fdr))
		{
		int n=SSL_read(ssl, from_ssl_buf, sizeof(from_ssl_buf));

			switch (SSL_get_error(ssl, n))	{
			case SSL_ERROR_NONE:
				if (n <= 0)	return;
				break;
			case SSL_ERROR_WANT_WRITE:
				suppress_read=1;
				suppress_write=0;
				continue;
			case SSL_ERROR_WANT_READ:
				suppress_read=0;
				suppress_write=1;
				continue;
			case SSL_ERROR_WANT_X509_LOOKUP:
				continue;
			default:
				return;
			}

			from_ssl_ptr=from_ssl_buf;
			from_ssl_cnt=n;
			suppress_read=0;
			suppress_write=0;
		}

		if (to_ssl_cnt && FD_ISSET(sslfd, &fdw))
		{
		int n=SSL_write(ssl, to_ssl_ptr, to_ssl_cnt);

			switch (SSL_get_error(ssl, n))	{
			case SSL_ERROR_NONE:
				if (n <= 0)	return;
				break;
			case SSL_ERROR_WANT_WRITE:
				suppress_write=0;
				suppress_read=1;
				continue;
			case SSL_ERROR_WANT_READ:
				suppress_read=0;
				suppress_write=1;
				continue;
			case SSL_ERROR_WANT_X509_LOOKUP:
				continue;
			default:
				return;
			}

			to_ssl_ptr += n;
			to_ssl_cnt -= n;
			suppress_read=0;
			suppress_write=0;
		}
		else if (!to_ssl_cnt && FD_ISSET(stdinfd, &fdr))
		{
		int n=read(stdinfd, to_ssl_buf, sizeof(to_ssl_buf));

			/* End of file or error on the plaintext side */
			if (n <= 0)	return;
			to_ssl_ptr=to_ssl_buf;
			to_ssl_cnt=n;
		}
	}
}

/* ----------------------------------------------------------------------- */

/*
** Apply socket options to the connection: turn SO_KEEPALIVE on and turn
** SO_LINGER off, where the platform defines them.  A failure is reported
** with perror() and is otherwise ignored.
*/
static void prepsocket(int sockfd)
{
#ifdef  SO_KEEPALIVE
	{
	int	enable=1;
	int	rc;

		rc=setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE,
			(const char *)&enable, sizeof(enable));
		if (rc < 0)
			perror("setsockopt");
	}
#endif

#ifdef  SO_LINGER
	{
	struct linger no_linger;
	int	rc;

		no_linger.l_linger=0;
		no_linger.l_onoff=0;

		rc=setsockopt(sockfd, SOL_SOCKET, SO_LINGER,
			(const char *)&no_linger, sizeof(no_linger));
		if (rc < 0)
			perror("setsockopt");
	}
#endif
}

/*
** Decide where the plaintext side of the connection goes.
**
** argn, argc, argv - argv[argn..argc-1] is the program to run, if any.
** fd               - the TLS socket; closed in the child process.
** stdin_fd,
** stdout_fd        - in/out: descriptors transfer_tls() will use.
**
** Three cases:
**  - localfd is set: both descriptors become atoi(localfd).
**  - no program given: the descriptors are left untouched.
**  - otherwise: the program is started with its stdin and stdout on one
**    end of an s_pipe() pair, and both descriptors become the other end.
**
** Exits the process if the pipe or the fork cannot be created.
*/
static void startclient(int argn, int argc, char **argv, int fd,
	int *stdin_fd, int *stdout_fd)
{
pid_t	child;
int	streampipe[2];

	if (localfd)
	{
		*stdin_fd= *stdout_fd= atoi(localfd);
		return;
	}

	if (argn >= argc)	return;		/* Interactive */

	if (s_pipe(streampipe))
	{
		perror("s_pipe");
		exit(1);
	}

	if ((child=fork()) == -1)
	{
		perror("fork");
		close(streampipe[0]);
		close(streampipe[1]);
		exit(1);
	}

	if (child == 0)
	{
	char **argvec;
	int n;

		close(fd);	/* Child process doesn't need it */

		/*
		** dup2() names its target explicitly and reports failure,
		** unlike close()+dup(), which relies on the lowest free
		** descriptor happening to be the one just closed.
		*/
		if (dup2(streampipe[1], 0) < 0 ||
			dup2(streampipe[1], 1) < 0)
		{
			perror("dup2");
			_exit(1);
		}

		/* Don't close what has just become stdin or stdout. */
		if (streampipe[0] > 1)
			close(streampipe[0]);
		if (streampipe[1] > 1)
			close(streampipe[1]);

		argvec=malloc(sizeof(char *)*(argc-argn+1));
		if (!argvec)
		{
			perror("malloc");
			_exit(1);
		}
		for (n=0; n<argc-argn; n++)
			argvec[n]=argv[argn+n];
		argvec[n]=0;

		execvp(argvec[0], argvec);
		perror(argvec[0]);

		/* _exit(): the forked child must not run the parent's
		** exit-time cleanup. */
		_exit(1);
	}
	close(streampipe[1]);

	*stdin_fd= *stdout_fd= streampipe[0];
}

/*
** Open a TCP connection to a remote server.
**
** host - an explicit IP address, or a name to resolve with rfc1035_a().
** port - a decimal port number (1-65535), or a "tcp" service name.
**
** Every address found for host is tried in turn until one connects.
**
** Returns the connected socket, or -1 after printing a diagnostic.
*/
static int connectremote(const char *host, const char *port)
{
int	fd;

RFC1035_ADDR addr;
int	af;
RFC1035_ADDR *addrs;
unsigned	naddrs, n;

RFC1035_NETADDR addrbuf;
const struct sockaddr *saddr;
int     saddrlen;
int	port_num;

	/* port_num ends up in network byte order on every path. */
	port_num=atoi(port);
	if (port_num <= 0)
	{
	struct servent *servent;

		servent=getservbyname(port, "tcp");

		if (!servent)
		{
			fprintf(stderr, "%s: invalid port.\n", port);
			return (-1);
		}
		port_num=servent->s_port;
	}
	else if (port_num > 65535)
	{
		/* htons() would silently truncate it to another port. */
		fprintf(stderr, "%s: invalid port.\n", port);
		return (-1);
	}
	else
		port_num=htons(port_num);

	if (rfc1035_aton(host, &addr) == 0) /* An explicit IP addr */
	{
		if ((addrs=malloc(sizeof(addr))) == 0)
		{
			perror("malloc");
			return (-1);
		}
		memcpy(addrs, &addr, sizeof(addr));
		naddrs=1;
	}
	else
	{
		if (rfc1035_a(&rfc1035_default_resolver, host, &addrs, &naddrs))
		{
			fprintf(stderr, "%s: not found.\n", host);
			return (-1);
		}
	}

	if ((fd=rfc1035_mksocket(SOCK_STREAM, 0, &af)) < 0)
	{
		perror("socket");
		free(addrs);	/* Was leaked on this path */
		return (-1);
	}

	/*
	** NOTE(review): the same socket is reused after a failed connect();
	** confirm that this is acceptable on every target platform.
	*/
	for (n=0; n<naddrs; n++)
	{
		if (rfc1035_mkaddress(af, &addrbuf, addrs+n, port_num,
			&saddr, &saddrlen))	continue;

		if (connect(fd, saddr, saddrlen) == 0)
			break;
	}
	free(addrs);

	if (n >= naddrs)
	{
		/* Report before close() gets a chance to change errno. */
		perror("connect");
		close(fd);
		return (-1);
	}

	return (fd);
}

/*
** Run one complete TLS session over a connected socket.
**
** fd               - connected socket.
** argn, argc, argv - handed to startclient() to pick the plaintext side.
**
** Creates the TLS context, performs the handshake (as a server when the
** -server option is set), relays data until either side stops, then
** releases the session and context.
**
** Returns 0 once the transfer has run, 1 if TLS setup failed.
*/

static int dossl(int fd, int argn, int argc, char **argv)
{
SSL_CTX *ctx;
SSL	*ssl;

int	stdin_fd, stdout_fd;

	prepsocket(fd);

	ctx=create_tls(server ? 1:0);
	if (ctx == 0)
	{
		close(fd);
		return (1);
	}

	ssl=connect_tls(ctx, server ? 1:0, fd);
	if (!ssl)
	{
		close(fd);
		/* No session, but the context still has to be released. */
		disconnect_tls(ctx, 0);
		return (1);
	}

	/* Defaults; startclient() replaces them for -localfd or a child. */
	stdin_fd=0;
	stdout_fd=1;

	startclient(argn, argc, argv, fd, &stdin_fd, &stdout_fd);

	transfer_tls(ssl, fd, stdin_fd, stdout_fd);
	disconnect_tls(ctx, ssl);
	return (0);
}

/*
** getenv() that never returns a null pointer: an unset variable is
** reported as the empty string.
*/
static const char *safe_getenv(const char *n)
{
const char *value;

	value=getenv(n);
	return (value ? value:"");
}

/*
** Entry point.  Reads the TLS_* environment variables, parses the
** command-line options, obtains the socket to secure, and runs the TLS
** session over it.
**
** The socket comes from, in order of precedence: -tcpd (descriptor 0),
** -remotefd (a descriptor number), or -host with -port (a new outgoing
** connection).
**
** Returns 0 once the session has run, 1 on any setup failure.
*/
int main(int argc, char **argv)
{
const char *s;
int	argn;
int	fd;

static struct args arginfo[] = {
	{ "host", &clienthost },
	{ "localfd", &localfd},
	{ "port", &clientport },
	{ "printx509", &printx509},
	{ "remotefd", &remotefd},
	{ "server", &server},
	{ "tcpd", &tcpd},
	{ "verify", &peer_verify_domain},
	{0}};

	/* An unset or empty variable leaves the built-in default alone. */
	s=safe_getenv("TLS_PROTOCOL");
	if (*s) protocol=s;

	s=safe_getenv("TLS_CIPHER_LIST");
	if (*s)	ssl_cipher_list=s;

	s=safe_getenv("TLS_TIMEOUT");
	session_timeout=atoi(s);

	s=safe_getenv("TLS_DHCERTFILE");
	if (*s)	dhcertfile=s;

	s=safe_getenv("TLS_CERTFILE");
	if (*s)	certfile=s;

	/* Only the first letter of TLS_VERIFYPEER is examined. */
	s=safe_getenv("TLS_VERIFYPEER");
	switch (*s)	{
	case 'n':
	case 'N':		/* NONE */
		peer_verify_level=SSL_VERIFY_NONE;
		break;
	case 'p':
	case 'P':		/* PEER */
		peer_verify_level=SSL_VERIFY_PEER;
		break;
	case 'r':
	case 'R':		/* REQUIREPEER */
		peer_verify_level=
			SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT;
		break;
	default:		/* Anything else: keep the default */
		break;
	}

	s=safe_getenv("TLS_PEERCERTDIR");
	if (*s)	peer_cert_dir=s;

	s=safe_getenv("TLS_OURCACERT");
	if (*s)	ourcacert=s;

	s=safe_getenv("TLS_ALLOWSELFSIGNEDCERT");
	okselfsignedcert=atoi(s);

	argn=argparse(argc, argv, arginfo);

	if (tcpd)
	{
		/*
		** The socket is descriptor 0.  Point stdout at stderr.
		** dup2() names its target and reports failure, unlike the
		** close(1)+dup(2) pair it replaces.
		*/
		if (dup2(2, 1) < 0)
		{
			perror("dup2");
			return (1);
		}
		fd=0;
	}
	else if (remotefd)
		fd=atoi(remotefd);
	else if (clienthost && clientport)
		fd=connectremote(clienthost, clientport);
	else
	{
		fprintf(stderr, "%s: specify remote location.\n", argv[0]);
		return (1);
	}

	if (fd < 0)	return (1);

	return (dossl(fd, argn, argc, argv));
}

