/* $Id: starttls.C,v 1.17 2005/08/17 04:05:46 dm Exp $ */

/*
 *
 * Copyright (C) 2005 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "asmtpd.h"
#include "async_ssl.h"

/* Set once the one-time OpenSSL library setup in ssl_init has run. */
static bool ssl_initialized;
/* opt->configno for which ssl_init last ran; lets ssl_init return its
   cached result until the configuration changes. */
static u_int64_t ssl_initno;
/* Server SSL_CTX built by ssl_init and handed to startssl by cmd_starttls. */
static SSL_CTX *ctx;
/* PRNG seed file written back by ssl_init and ssl_shutdown; only set
   when not built with SFS_DEV_RANDOM, otherwise stays NULL. */
static bssstr ssl_randfile;

/*
 * atexit hook registered by ssl_init: saves the PRNG state to the seed
 * file (if one is in use) and then forgets the file name, so a repeated
 * call does nothing.
 */
static void
ssl_shutdown ()
{
  if (!ssl_randfile)
    return;

  RAND_write_file (ssl_randfile);
  ssl_randfile = NULL;
}

static u_int64_t ssl_tmpkey_initno;
static RSA *ssl_tmpkey_key;
timecb_t *ssl_tmpkey_tmo;
static void
ssl_tmpkey_erase (bool timedout = false)
{
  if (!timedout && ssl_tmpkey_tmo)
    timecb_remove (ssl_tmpkey_tmo);
  ssl_tmpkey_tmo = NULL;
  RSA_free (ssl_tmpkey_key);
  ssl_tmpkey_key = NULL;
}
/*
 * OpenSSL temporary-RSA-key callback (installed by ssl_init with
 * SSL_CTX_set_tmp_rsa_callback).  Returns the cached key, generating a
 * KEYLEN-bit key (public exponent 17) when there is none.  The cache is
 * discarded when the configuration generation (opt->configno) changes,
 * and a delaycb timer discards it 3600 seconds after it is armed.
 *
 * SSL and EXP are unused.
 * NOTE(review): a cached key is returned even if KEYLEN differs from
 * the length it was generated with.
 */
static RSA *
ssl_tmpkey (SSL *ssl, int exp, int keylen)
{
  // Remember which generation the cache belongs to.  Without recording
  // it, this test succeeds on every call and a new RSA key is generated
  // for every handshake.
  if (ssl_tmpkey_initno != opt->configno) {
    ssl_tmpkey_erase ();
    ssl_tmpkey_initno = opt->configno;
  }

  if (ssl_tmpkey_key)
    return ssl_tmpkey_key;

  ssl_tmpkey_key = RSA_generate_key (keylen, 17, NULL, NULL);

  // Bound the key's lifetime; the timer passes timedout=true so erase
  // does not try to cancel the timer that is calling it.
  if (!ssl_tmpkey_tmo)
    ssl_tmpkey_tmo = delaycb (3600, wrap (ssl_tmpkey_erase, true));
  return ssl_tmpkey_key;
}

static DH *
ssl_dhparm (SSL *, int exp, int len)
{
  static qhash<int, DH *> parmtab;
  if (DH **dhp = parmtab[len])
    return *dhp;

  DH *dh = NULL;
  str cachefile = strbuf ("%s/dh%d.pem", opt->etcdir.cstr (), len);
  if (FILE *fp = fopen (cachefile, "r")) {
    dh = PEM_read_DHparams(fp, NULL, NULL, NULL);
    fclose (fp);
  }

  if (!dh) {
    warn << "Generating " << len << "-bit DH parameters... ";
    err_flush ();
    dh = DH_generate_parameters (len, 2, NULL, NULL);
    if (!dh) {
      warn << "failed: " << ssl_err () << "\n";
      return NULL;
    }
    warnx << "done\n";

    int fd = open (cachefile, O_CREAT|O_EXCL|O_WRONLY, 0644);
    if (fd >= 0) {
      if (FILE *fp = fdopen (fd, "w")) {
	PEM_write_DHparams (fp, dh);
	fclose (fp);
      }
      else
	close (fd);
    }
  }

  parmtab.insert (len, dh);
  return dh;
}

/*
 * Certificate-verification callback that ssl_init installs together with
 * SSL_VERIFY_PEER.  It ignores OpenSSL's own verdict and always returns 1,
 * so a client certificate that fails verification never aborts the
 * handshake.
 */
static int
verify_cb (int /* preverify_ok */, X509_STORE_CTX * /* store */)
{
  const int accept_certificate = 1;
  return accept_certificate;
}

bool
ssl_init ()
{
  if (ssl_initno == opt->configno)
    return opt->ssl_status > 0;
  opt->ssl_status = -1;
  ssl_initno = opt->configno;

  ssl_tmpkey_erase ();

  if (!opt->ssl)
    return false;

  struct stat sb;
  if (stat (opt->ssl_key, &sb) < 0) {
    if (errno == ENOENT)
      warn << "STARTTLS disabled (no file " << opt->ssl_key << ")\n";
    else
      warn << opt->ssl_key << ":  " << strerror (errno) << "\n";
    return false;
  }
  if (sb.st_mode & 044)
    warn << "DANGER: " << opt->ssl_key << " should be read-protected\n";

  if (!ssl_initialized) {
    ssl_initialized = true;
    SSL_load_error_strings ();
    SSL_library_init ();
#ifdef SFS_DEV_RANDOM
    if (RAND_load_file (SFS_DEV_RANDOM, 1024) <= 0)
      warn << "DANGER: " << SFS_DEV_RANDOM ": " << ssl_err () << "\n";
#else /* !SFS_DEV_RANDOM */
    ssl_randfile = opt->etcdir << "/.rnd";
    if (RAND_load_file (ssl_randfile, 1024) <= 0)
      warn << "DANGER: " << ssl_randfile << ": " << ssl_err () << "\n";
#endif /* !SFS_DEV_RANDOM */
    atexit (ssl_shutdown);
  }
  else if (ssl_randfile)
    RAND_write_file (ssl_randfile);

  if (ctx)
    SSL_CTX_free (ctx);
  ctx = SSL_CTX_new (SSLv23_server_method ());
  if (!ctx) {
    warn << "SSL_CTX_new: " << ssl_err () << "\n";
    return false;
  }

  bool verify = true;
  if (SSL_CTX_load_verify_locations (ctx, opt->ssl_ca, NULL) <= 0) {
    warn << opt->ssl_ca << ": " << ssl_err () << "\n";
    verify = false;
  }
  if (verify)
    SSL_CTX_set_verify (ctx, SSL_VERIFY_PEER, verify_cb);

  if (!access (opt->ssl_crl, 0)) {
    X509_STORE *store = SSL_CTX_get_cert_store (ctx);
    X509_LOOKUP *lookup;
    if ((lookup = X509_STORE_add_lookup (store, X509_LOOKUP_file ()))
	&& (X509_load_crl_file(lookup, opt->ssl_crl, X509_FILETYPE_PEM) == 1))
      X509_STORE_set_flags(store, (X509_V_FLAG_CRL_CHECK |
				   X509_V_FLAG_CRL_CHECK_ALL));
  }

  if (SSL_CTX_use_certificate_file (ctx, opt->ssl_cert,
				    SSL_FILETYPE_PEM) <= 0) {
    warn << opt->ssl_cert << ": " << ssl_err () << "\n";
    return false;
  }

  if (SSL_CTX_use_PrivateKey_file (ctx, opt->ssl_key, SSL_FILETYPE_PEM) <= 0) {
    warn << opt->ssl_key << ": " << ssl_err () << "\n";
    return false;
  }

  if (SSL_CTX_check_private_key (ctx) <= 0) {
    warn << opt->ssl_key << ": " << ssl_err () << "\n";
    return false;
  }

  SSL_CTX_set_options (ctx, SSL_OP_ALL|SSL_OP_NO_SSLv2);
  if (opt->ssl_ciphers
      && SSL_CTX_set_cipher_list (ctx, opt->ssl_ciphers) <= 0) {
    warn << "Cipher list " << opt->ssl_ciphers << ": " << ssl_err () << "\n";
    return false;
  }

  SSL_CTX_set_tmp_rsa_callback (ctx, ssl_tmpkey);
  SSL_CTX_set_tmp_dh_callback (ctx, ssl_dhparm);

  u_char sessid[SSL_MAX_SSL_SESSION_ID_LENGTH];
  if (RAND_pseudo_bytes (sessid, sizeof (sessid)) >= 0)
    SSL_CTX_set_session_id_context (ctx, sessid, sizeof (sessid));

  opt->ssl_status = 1;
  return true;
}


/*
 * EHLO extension line advertising STARTTLS.  Returns it only while the
 * session is neither encrypted nor authenticated and TLS is configured
 * (opt->ssl_status > 0); otherwise returns the empty string.
 */
str
smtpd::helo_starttls ()
{
  if (encrypted || auth_user || opt->ssl_status <= 0)
    return "";
  return "250-STARTTLS\r\n";
}

/*
 * STARTTLS command handler.  Replies 502 if the session is already
 * encrypted or authenticated, TLS is not configured, or the stream is
 * not an aiossl.  Otherwise it calls reset (), sends 220, starts TLS on
 * the existing stream with the context built by ssl_init, registers
 * set_quota_user as the stream's verify callback, and resumes reading
 * commands.  CMD and ARG are unused.
 */
void
smtpd::cmd_starttls (str cmd, str arg)
{
  // Check the downcast before replying.  The other TLS accessors in this
  // file use dynamic_cast, and an unchecked static_cast on a stream that
  // is not an aiossl would be undefined behavior.
  aiossl *ssl = dynamic_cast<aiossl *> (aio.get ());
  if (encrypted || auth_user || opt->ssl_status <= 0 || !ssl) {
    respond ("502 command not implemented\r\n");
    return;
  }

  reset ();
  aio << "220 Starting TLS\r\n";

  ssl->startssl (ctx, true);
  encrypted = true;

  // Presumably invoked once the peer certificate is available -- confirm
  // against aiossl::verify_cb.
  ssl->verify_cb (wrap (this, &smtpd::set_quota_user));

  cmdwait = true;
  aio->readline (wrap (this, &smtpd::getcmd));
}

/*
 * Verify callback registered by cmd_starttls.  If the TLS peer has a
 * certificate subject, use it as quota_user and, when opt->ssl >= 2,
 * raise trust to at least TRUST_AUTH.  Does nothing on an unencrypted
 * session or a stream that is not an aiossl.
 */
void
smtpd::set_quota_user ()
{
  if (!encrypted)
    return;
  aiossl *ssl = dynamic_cast<aiossl *> (aio.get ());
  if (!ssl || !ssl->subject)
    return;

  quota_user = ssl->subject;
  // NOTE(review): verify_cb accepts every certificate, so granting
  // TRUST_AUTH here relies on aiossl only filling in `subject` for a
  // chain that actually verified -- confirm in async_ssl.
  if (opt->ssl >= 2 && trust < TRUST_AUTH)
    trust = TRUST_AUTH;
}

/*
 * Append this session's TLS details to R (presumably for the Received:
 * trace header -- confirm at the caller).  Appends nothing unless the
 * session is encrypted over an aiossl with a negotiated cipher.
 * Output is
 *     "    (<version> <cipher> <bits>/<algbits>)\n"
 * or, when the peer presented a certificate subject,
 *     "    (<version> <cipher> <bits>/<algbits>,\n     [issuer=...,] subject=...)\n"
 * where <version> falls back to "SSL" if OpenSSL reports none.
 */
void
smtpd::received_starttls (strbuf r) const
{
  if (!encrypted)
    return;
  aiossl *ssl = dynamic_cast<aiossl *> (aio.get ());
  if (!ssl || !ssl->cipher)
    return;

  r << "    (";
  // const: OpenSSL >= 1.0 returns const char * here; this also accepts
  // the older char * return.
  if (const char *vers = SSL_get_cipher_version (ssl->get_ssl ()))
    r << vers;
  else
    r << "SSL";
  int alg_bits = 0;
  int cipher_bits = SSL_get_cipher_bits (ssl->get_ssl (), &alg_bits);
  r << " " << ssl->cipher << " " << cipher_bits << "/" << alg_bits;

  if (ssl->subject) {
    r << ",\n    ";
    if (ssl->issuer)
      r << " issuer=" << ssl->issuer << ",";
    r << " subject=" << ssl->subject << ")\n";
  }
  else
    r << ")\n";
}

/*
 * Append this session's TLS details to *ENVP as NAME=value strings:
 * SSL_CIPHER, SSL_CIPHER_BITS and SSL_ALG_BITS always, plus SSL_VERSION,
 * SSL_ISSUER, SSL_ISSUER_DN, SSL_SUBJECT and SSL_SUBJECT_DN when known.
 * Appends nothing unless the session is encrypted over an aiossl with a
 * negotiated cipher.
 */
void
smtpd::env_starttls (vec<str> *envp) const
{
  if (!encrypted)
    return;
  aiossl *ssl = dynamic_cast<aiossl *> (aio.get ());
  if (!ssl || !ssl->cipher)
    return;

  envp->push_back (strbuf () << "SSL_CIPHER=" << ssl->cipher);

  int alg_bits = 0;
  int cipher_bits = SSL_get_cipher_bits (ssl->get_ssl (), &alg_bits);
  envp->push_back (strbuf ("SSL_CIPHER_BITS=%d", cipher_bits));
  envp->push_back (strbuf ("SSL_ALG_BITS=%d", alg_bits));

  // const: OpenSSL >= 1.0 returns const char * here; this also accepts
  // the older char * return.
  if (const char *vers = SSL_get_cipher_version (ssl->get_ssl ()))
    envp->push_back (strbuf () << "SSL_VERSION=" << vers);

  if (ssl->issuer)
    envp->push_back (strbuf () << "SSL_ISSUER=" << ssl->issuer);
  if (ssl->issuer_dn)
    envp->push_back (strbuf () << "SSL_ISSUER_DN=" << ssl->issuer_dn);
  if (ssl->subject)
    envp->push_back (strbuf () << "SSL_SUBJECT=" << ssl->subject);
  if (ssl->subject_dn)
    envp->push_back (strbuf () << "SSL_SUBJECT_DN=" << ssl->subject_dn);
}
