/*
 * Copyright (c) 2002, The EROS Group, LLC and Johns Hopkins
 * University. All rights reserved.
 * 
 * This software was developed to support the EROS secure operating
 * system project (http://www.eros-os.org). The latest version of
 * the OpenCM software can be found at http://www.opencm.org.
 * 
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 
 * 2. Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 * 
 * 3. Neither the name of the The EROS Group, LLC nor the name of
 *    Johns Hopkins University, nor the names of its contributors
 *    may be used to endorse or promote products derived from this
 *    software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <opencm.h>
#include <openssl/err.h>

/* Channel helpers implemented elsewhere.  Callers in this file treat a
 * negative return value from either one as failure. */
extern int chan_SetNonblocking(Channel *);
extern int chan_connect(Channel *, URI *);

/* SSL implementations of the Channel operations; chan_init_ssl() installs
 * them in every SSL channel it creates. */
static void ssl_send(Channel *c, const void *buf, size_t len, ChannelCompletionFn fn);
static void ssl_flush(Channel *c, OC_bool andWait);
static void ssl_receive(Channel *c, size_t len, ChannelCompletionFn fn);
static void ssl_pull(Channel *c, OC_bool andWait);
static void ssl_aclose(Channel *c);
static void ssl_close(Channel *c);

/* Handshake drivers: chan_connect_ssl() installs the client-side one and
 * chan_accept_ssl() the server-side one as the channel's shake function. */
static void ssl_connect_shake(Channel *c, OC_bool wait);
static void ssl_accept_shake(Channel *c, OC_bool wait);

/*
 * Build a Channel whose operations are the SSL versions defined in this
 * file.  The SSL object and socket BIO start out unset; the connect and
 * accept paths fill them in once a socket exists.  When fn is non-NULL it
 * is installed as the channel's read-completion callback (rCallback).
 */
static Channel *
chan_init_ssl(ChannelCompletionFn fn)
{
  Channel *chan = chan_create();

  /* Operation table. */
  chan->send    = ssl_send;
  chan->flush   = ssl_flush;
  chan->receive = ssl_receive;
  chan->pull    = ssl_pull;
  chan->aclose  = ssl_aclose;
  chan->close   = ssl_close;

  assert(ssl_ctx != NULL);

  /* SSL state: shared context now, per-connection objects later. */
  chan->ctx  = ssl_ctx;
  chan->sbio = NULL;
  chan->ssl  = NULL;

  if (fn != NULL)
    chan->rCallback = fn;

  return chan;
}

#ifdef USE_CALIST
/*
 * Check the peer certificate on ssl against host.
 *
 * Use this only if you need to verify certs against a separate list of CA
 * certs.  We normally use self-signed certs and thus don't use this method.
 *
 * Throws ExBadValue if any of the following holds:
 *  - the verify result is not X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT;
 *  - the peer sent no certificate;
 *  - the certificate subject has no commonName;
 *  - the commonName does not match host (case-insensitive).
 *
 * NOTE(review): requiring a self-signed verify result contradicts the
 * stated purpose (verifying against a CA list would normally yield
 * X509_V_OK).  Confirm the intended check before enabling USE_CALIST.
 */
static void
check_cert_chain(SSL *ssl, const char *host)
{
  int result;
  int cn_len;
  X509 *peer;
  char peer_CN[256];

  result = SSL_get_verify_result(ssl);
  if (result != X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT)
    THROW(ExBadValue,
          format("Certificate is not self-signed: %d.\n", result));

  /* The chain length is checked by OpenSSL when we set the verify depth
   * in the context; all we do here is check that the CN matches. */
  peer = SSL_get_peer_certificate(ssl);
  if (peer == NULL)
    THROW(ExBadValue, "Peer presented no certificate");

  cn_len = X509_NAME_get_text_by_NID(X509_get_subject_name(peer),
                                     NID_commonName, peer_CN,
                                     sizeof(peer_CN));

  /* SSL_get_peer_certificate() handed us a reference.  Release it before
   * any THROW below can transfer control out of this function. */
  X509_free(peer);

  /* On failure peer_CN is not written, so it must not be compared. */
  if (cn_len < 0)
    THROW(ExBadValue, "Certificate has no common name");

  if (strcasecmp(peer_CN, host) != 0)
    THROW(ExBadValue, "Common name doesn't match host name");
}
#endif

/*
 * Create the listening Channel for an SSL server on host:port.
 *
 * A listening socket carries no SSL traffic itself, so this is just the
 * plain TCP listener with handshake_done preset to TRUE.  Per-connection
 * SSL state is set up later by chan_accept_ssl().  Returns NULL if
 * chan_alisten_tcp() does.
 */
Channel *
chan_alisten_ssl(ChannelCompletionFn fn, char *host, unsigned short port)
{
  Channel *listener = chan_alisten_tcp(fn, host, port);

  if (listener == NULL)
    return NULL;

  /* No handshake is ever needed on the listening channel. */
  listener->handshake_done = TRUE;
  return listener;
}

/*
 * Log the OpenSSL error code errcode under the TRC_SSL trace category.
 *
 * errcode is unsigned long because that is the type ERR_get_error()
 * returns; callers pass ERR_get_error() straight through.  Behaviour by
 * value:
 *  - 0 means OpenSSL had no error queued;
 *  - if OpenSSL has no reason string for the code, the raw number is
 *    logged instead.
 *
 * level is currently unused.
 */
static void
report_ssl_error(int level, unsigned long errcode)
{
  const char *reason;

  (void) level;

  if (errcode == 0) {
    log_trace(TRC_SSL, "No OpenSSL error reported (!)\n");
    return;
  }

  reason = ERR_reason_error_string(errcode);
  if (reason == NULL)
    /* %lu: "%l" alone is not a valid conversion and is undefined. */
    log_trace(TRC_SSL, "OpenSSL error code %lu\n", errcode);
  else
    log_trace(TRC_SSL, "%s\n", reason);
}

/*
 * Complete the client side of a handshake once SSL_connect() succeeded.
 *
 * The server must present a certificate whose verify result is
 * X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT.
 *  - If it does not, the channel is closed with c->close() and no
 *    handshake callback runs.
 *  - Otherwise the server's public key is stored in c->peerCert, the
 *    pending handshake callback (c->hCallback) runs once, and
 *    c->handshake_done is set.
 */
static void
ssl_connect_finish(Channel *c)
{
  X509 *peer = SSL_get_peer_certificate((SSL *)(c->ssl));
  long result;

  if (peer == NULL) {
    log_trace(TRC_CRYPTO, "Server presented no certificate\n");
    c->close(c);
    return;
  }

  result = SSL_get_verify_result((SSL *)(c->ssl));
  if (result != X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT) {
    log_trace(TRC_CRYPTO,
              "Server cert isn't self-signed or is expired: %ld\n", result);
    /* We hold a reference from SSL_get_peer_certificate(); drop it. */
    X509_free(peer);
    c->close(c);
    return;  /* channel is gone: no peerCert, no callback */
  }

  /* Store a user reference in this SSL Channel temporarily.
   * NOTE(review): peer carries a reference from SSL_get_peer_certificate();
   * whether pubkey_from_X509() takes it over is not visible here, so it is
   * deliberately not freed on this path. */
  c->peerCert = pubkey_from_X509(peer);

  /* Only tell the caller about the connection after the cert checked out. */
  if (c->hCallback) {
    ChannelCompletionFn fn = c->hCallback;

    c->hCallback = NULL;
    fn(c);
  }
  c->handshake_done = TRUE;
}

/*
 * Drive the client side of the SSL handshake on channel c.
 *
 * The underlying BIO is non-blocking, so SSL_connect() may ask to be
 * retried (SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE).
 *  - With andWait FALSE we make a single attempt and return, expecting to
 *    be called again.
 *  - With andWait TRUE we block on the socket between attempts rather
 *    than spinning.
 *
 * Success is handed to ssl_connect_finish().  Any other SSL error is
 * logged and the channel is closed with c->close().
 */
static void
ssl_connect_shake(Channel *c, OC_bool andWait)
{
  for (;;) {
    int r = SSL_connect((SSL *)(c->ssl));

    switch (SSL_get_error((SSL *)(c->ssl), r)) {
    case SSL_ERROR_NONE:
      ssl_connect_finish(c);
      return;

    case SSL_ERROR_WANT_READ:
      if (!andWait)
        return;
      chan_blockForReading(c, NULL);
      break;

    case SSL_ERROR_WANT_WRITE:
      if (!andWait)
        return;
      chan_blockForWriting(c, NULL);
      break;

    default:
      report_ssl_error(0, ERR_get_error());
      c->close(c);
      return;
    }
  }
}

/*
 * Open a client SSL Channel to uri.
 *
 * The TCP connection is made by chan_connect(), and the SSL object and
 * socket BIO are created here.  The handshake itself is not performed:
 * the returned channel has handshake_done == FALSE and ssl_connect_shake()
 * installed as its shake function.
 *
 * Returns NULL if the connection fails or the SSL state cannot be
 * allocated; in the latter case the connected socket is closed.
 */
Channel *
chan_connect_ssl(URI *uri)
{
  Channel *c = chan_init_ssl(NULL);

  if (chan_connect(c, uri) < 0)
    return NULL;

  /* Now do SSL specific stuff */
  c->ssl = SSL_new((SSL_CTX *)(c->ctx));
  if (c->ssl == NULL) {
    report_ssl_error(0, ERR_get_error());
    close(c->sockfd);
    return NULL;
  }

  c->sbio = BIO_new_socket(c->sockfd, BIO_NOCLOSE);
  if (c->sbio == NULL) {
    report_ssl_error(0, ERR_get_error());
    SSL_free((SSL *)(c->ssl));
    c->ssl = NULL;
    close(c->sockfd);
    return NULL;
  }

  /* The SSL object takes ownership of the BIO from here on. */
  SSL_set_bio((SSL *)(c->ssl), (BIO *)(c->sbio), (BIO *)(c->sbio));

  c->shake = ssl_connect_shake;
  c->handshake_done = FALSE;
  c->hCallback = NullChannelCallback;

  return c;
}

/*
 * Accept one pending connection on the listening channel c and wrap it in
 * a new, non-blocking SSL server Channel.
 *
 * Returns NULL when no connection could be accepted, including the
 * EWOULDBLOCK/EAGAIN "nothing pending" case.  Also returns NULL when the
 * new channel could not be set up; the accepted socket is closed then.
 *
 * The SSL handshake is not performed here: the returned channel has
 * handshake_done == FALSE and ssl_accept_shake() installed as its shake
 * function.
 */
Channel *
chan_accept_ssl(Channel *c)
{
  int fd;
  int fdflags;
  struct sockaddr_in cliaddr;
  socklen_t clientlen;
  Channel *client;

  memset(&cliaddr, 0, sizeof(cliaddr));
  clientlen = sizeof(cliaddr);
  fd = accept(c->sockfd, (struct sockaddr *) &cliaddr, &clientlen);
  if (fd < 0) {
    int err = errno;

    /* EWOULDBLOCK/EAGAIN just mean no connection is pending.  cliaddr is
     * not filled in on failure, so report errno rather than an address. */
    if (err != EWOULDBLOCK && err != EAGAIN)
      log_trace(TRC_COMM, "Server Accept failed: %s\n", strerror(err));
    return NULL;
  }

  /* Ensure that client connections get closed when we run
     subprocesses.  Don't OR FD_CLOEXEC into an F_GETFD error (-1). */
  fdflags = fcntl(fd, F_GETFD);
  if (fdflags >= 0)
    fcntl(fd, F_SETFD, fdflags | FD_CLOEXEC);

  log_trace(DBG_REQUEST, "Processing request from %s\n",
            inet_ntoa(cliaddr.sin_addr));

  client = chan_init_ssl(NULL);
  client->sockfd = fd;
  if (chan_SetNonblocking(client) < 0) {
    log_trace(TRC_COMM, "Server couldn't make non-blocking Channel for %s\n",
              inet_ntoa(cliaddr.sin_addr));
    close(fd);
    return NULL;
  }

  /* Now do SSL specific stuff */
  client->sbio = BIO_new_socket(client->sockfd, BIO_NOCLOSE);
  client->ssl  = SSL_new(ssl_ctx);
  if (client->sbio == NULL || client->ssl == NULL) {
    log_trace(TRC_COMM, "Server couldn't create SSL state for %s\n",
              inet_ntoa(cliaddr.sin_addr));
    report_ssl_error(0, ERR_get_error());
    /* Not yet linked by SSL_set_bio(), so each is freed separately. */
    if (client->ssl != NULL)
      SSL_free((SSL *)(client->ssl));
    if (client->sbio != NULL)
      BIO_free((BIO *)(client->sbio));
    close(fd);
    return NULL;
  }
  SSL_set_bio((SSL *)(client->ssl), (BIO *)(client->sbio), (BIO *)(client->sbio));

  /* Use server version of the handshaking code */
  client->shake = ssl_accept_shake;
  client->handshake_done = FALSE;
  client->hCallback = NullChannelCallback;

  return client;
}

/*
 * Examine the client's certificate after SSL_accept() has succeeded, and
 * either complete the handshake or close the channel.
 *
 * The channel is closed asynchronously (c->aclose) in three cases:
 *  - ExConnLost is raised while fetching the certificate;
 *  - the client presented no certificate;
 *  - the verify result is anything other than
 *    X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT.
 * In none of those cases does the handshake callback run.
 *
 * Otherwise the client's public key is stored in c->peerCert, the pending
 * handshake callback (c->hCallback) runs once, and c->handshake_done is
 * set.
 */
static void
ssl_accept_finish(Channel *c)
{
  X509 *peer = NULL;
  OC_bool connLost = FALSE;
  long result;

  /* (Client authentication proper happens inside SSL_accept(); this is
   * where the app can do further checks on the certificate contents.) */
  TRY {
    peer = SSL_get_peer_certificate((SSL *)(c->ssl));
  }
  CATCH(ExConnLost) {
    log_trace(TRC_SSL, "Caught ExConnLost while soliciting peer certificate.\n");
    log_trace(TRC_SSL, "NOTICE: This may be illegal in some states.\n");
    connLost = TRUE;
  }
  END_CATCH;

  if (connLost) {
    c->aclose(c);
    return;
  }

  /* Without a certificate the handshake can never be marked done, so the
   * channel would otherwise linger half-open; close it instead. */
  if (peer == NULL) {
    log_trace(TRC_CRYPTO, "Client presented no certificate.\n");
    c->aclose(c);
    return;
  }

  result = SSL_get_verify_result((SSL *)(c->ssl));
  if (result != X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT) {
    log_trace(TRC_CRYPTO, "Client cert isn't self-signed or is expired.\n");
    report_ssl_error(0, ERR_get_error());
    X509_free(peer);
    c->aclose(c);
    return;  /* must not fall through to the success path below */
  }

  /* Store a user reference in this SSL Channel temporarily.
   * NOTE(review): peer carries a reference from SSL_get_peer_certificate();
   * whether pubkey_from_X509() takes it over is not visible here, so it is
   * deliberately not freed on this path. */
  c->peerCert = pubkey_from_X509(peer);

  if (c->hCallback) {
    ChannelCompletionFn fn = c->hCallback;

    c->hCallback = NULL;
    fn(c);
  }
  c->handshake_done = TRUE;
}

/*
 * Drive the server side of the SSL handshake on an accepted channel.
 *
 * The socket is non-blocking, so SSL_accept() may need to be retried.
 *  - With wait FALSE we make a single attempt and return, expecting to be
 *    called again.
 *  - With wait TRUE we block on the socket between attempts rather than
 *    spinning.
 *
 * ExConnLost during SSL_accept() and any other SSL error close the channel
 * asynchronously.  Success is handed to ssl_accept_finish().
 */
static void
ssl_accept_shake(Channel *c, OC_bool wait)
{
  int r = 0;

  for (;;) {
    OC_bool connLost = FALSE;

    TRY {
      r = SSL_accept((SSL *)(c->ssl));
    }
    CATCH(ExConnLost) {
      log_trace(TRC_SSL, "Client connection terminated during SSL handshake.\n");
      connLost = TRUE;
    }
    END_CATCH;

    /* Close only after END_CATCH has run. */
    if (connLost) {
      c->aclose(c);
      return;
    }

    switch (SSL_get_error((SSL *)(c->ssl), r)) {
    case SSL_ERROR_NONE:
      ssl_accept_finish(c);
      return;

    case SSL_ERROR_WANT_READ:
      if (!wait)
        return;
      chan_blockForReading(c, NULL);
      break;

    case SSL_ERROR_WANT_WRITE:
      if (!wait)
        return;
      chan_blockForWriting(c, NULL);
      break;

    default:
      log_trace(TRC_COMM, "Server SSL accept error.\n");
      report_ssl_error(0, ERR_get_error());
      c->aclose(c);
      return;
    }
  }
}

/*
 * Push queued writes (c->writeQueue) out through SSL, oldest first.
 *
 * Each PendingWrite is written until pw->sent reaches pw->len.  It is then
 * removed from the queue and its callback runs once.
 *  - With blocking FALSE we return as soon as SSL cannot make progress;
 *    the unsent remainder stays queued for a later call.
 *  - With blocking TRUE we block on the socket and retry.
 *
 * ExConnLost closes the channel asynchronously.  Any other exception from
 * SSL_write() closes it and is rethrown.  An SSL error closes it and drops
 * the remaining queue without running callbacks.  We also stop if a
 * completion callback closed the channel.
 */
static void
ssl_flush(Channel *c, OC_bool blocking)
{
  int r = 0;

  while (c->writeQueue) {
    PendingWrite *pw = c->writeQueue;
    ChannelCompletionFn fn;

    while (pw->sent < pw->len) {
      TRY {
        r = SSL_write(c->ssl, pw->buf + pw->sent, pw->len - pw->sent);
      } /* end of the TRY block */
      CATCH(ExConnLost) {
        log_trace(ERR_COMM, "Conn lost on SSL write. r = %d\n", r);
        c->aclose(c);
      }
      DEFAULT(ex) {
        log_trace(ERR_COMM, "Misc error on SSL write. r = %d\n", r);
        c->aclose(c);
        RETHROW(ex);
      }
      END_CATCH;

      /* Be careful: we can't return inside the CATCH (above), because
       * END_CATCH must be executed, or _bad_ _things_ can happen. */
      if (c->closed)
        return;

      switch (SSL_get_error((SSL *)(c->ssl), r)) {
      case SSL_ERROR_NONE:
        pw->sent += r;
        break;

      case SSL_ERROR_WANT_READ:
        /* SSL must read from the peer before this write can proceed. */
        if (blocking)
          chan_blockForReading(c, NULL);
        break;

      case SSL_ERROR_WANT_WRITE:
        c->pendingWrite = TRUE;
        if (blocking)
          chan_blockForWriting(c, NULL);
        break;

      default:
        report_ssl_error(0, ERR_get_error());
        c->writeQueue = NULL;
        c->aclose(c);
        return;
      }

      /* Only SSL_ERROR_NONE can advance pw->sent, so only it gets here. */
      if (pw->sent == pw->len) {
        c->pendingWrite = FALSE;
        break;
      }

      if (!blocking)
        return;         /* not done, not waiting */

      if (c->closed)
        return;
    }

    /* Dequeue before running the callback.  A callback that re-enters
     * flush then never sees this entry again; the old code would have
     * called its now-NULL callback.  A callback that calls send simply
     * appends to the remaining queue. */
    c->writeQueue = pw->next;
    fn = pw->callback;
    pw->callback = NULL;
    fn(c);

    /* The callback may have closed the channel; don't write to it. */
    if (c->closed)
      return;
  }
}

/*
 * Make progress on the read armed by ssl_receive(); returns at once if no
 * read is outstanding (c->rCallback is NULL).
 *
 * c->read_wanted == 0 means the caller is waiting for activity on the
 * channel (for example an incoming connection) rather than for data.  We
 * optionally block until the socket is readable, then run the callback.
 *
 * Otherwise SSL_read() is called repeatedly, appending to c->read_buf,
 * until c->read_wanted bytes have arrived; then the callback runs once.
 *  - With blocking FALSE we return as soon as SSL cannot make progress.
 *    c->pendingRead records whether SSL still holds buffered data.
 *  - With blocking TRUE we block on the socket and retry.
 *
 * A zero-byte SSL_read() or ExConnLost closes the channel asynchronously.
 * Any other SSL error closes it and throws ExIoError.
 */
static void
ssl_pull(Channel *c, OC_bool blocking)
{
  int nread = 0;

  if (!c->rCallback)
    return;

  if (c->read_wanted == 0) {
    if (blocking)
      chan_blockForReading(c, NULL);
  }
  else {
    /* Keep draining SSL until we have what was asked for. */
    while (c->read_so_far < c->read_wanted) {
      size_t limit = c->read_wanted - c->read_so_far;
      int saved_errno;

      limit = min(limit, CHAN_RDBOUND);

      TRY {
        nread = SSL_read((SSL *)(c->ssl), c->read_scratch, (int) limit);
      }
      CATCH(ExConnLost) {
        c->aclose(c);
        log_trace(ERR_COMM, "Conn lost on SSL read. nread = %d\n", nread);
      }
      END_CATCH;

      /* After ExConnLost, nread still holds the previous iteration's value.
       * Processing it would append stale data, so stop here. */
      if (c->closed)
        return;

      saved_errno = errno;

      if (nread == 0) {
        log_trace(ERR_COMM, "Read attempt over SSL returned zero bytes. "
                  "wanted %lu have %lu limit %lu error %s "
                  "SSL err = %d (Closing channel...)\n",
                  (unsigned long) c->read_wanted,
                  (unsigned long) c->read_so_far,
                  (unsigned long) limit, strerror(saved_errno),
                  SSL_get_error((SSL *)(c->ssl), nread));
        c->aclose(c);
        return;
      }

      switch (SSL_get_error((SSL *)(c->ssl), nread)) {
      case SSL_ERROR_NONE:
        buffer_append(c->read_buf, c->read_scratch, nread);
        c->read_so_far += nread;
        break;
      case SSL_ERROR_WANT_READ:
        if (blocking)
          chan_blockForReading(c, NULL);
        break;
      case SSL_ERROR_WANT_WRITE:
        if (blocking)
          chan_blockForWriting(c, NULL);
        break;
      default:
        c->aclose(c);
        THROW(ExIoError, format("I/O error on channel %d read: %s",
                                c->sockfd, strerror(saved_errno)));
      }

      if (c->closed)
        return;

      c->pendingRead = (SSL_pending(c->ssl) > 0);

      /* Check if we have what we asked for: */
      if (c->read_so_far == c->read_wanted)
        break;

      /* If we aren't done and there is pendingRead stuff,
         chan_select() will call us again. */
      if (!blocking)
        return;
    }
  }

  if (!c->closed) {
    ChannelCompletionFn fn = c->rCallback;

    c->rCallback = NULL;
    fn(c);
  }
}

/**
  * Queue buf[0..len) for sending on this Channel.
  *
  * A PendingWrite is appended to the tail of the outbound (write) queue.
  * Nothing is written here; ssl_flush() sends the data.  fn, or
  * NullChannelCallback when fn is NULL, runs once the whole buffer has
  * been written.
  *
  * buf is not copied, so it must stay valid until fn has run.
  */
static void
ssl_send(Channel *c, const void *buf, size_t len, ChannelCompletionFn fn) 
{
  PendingWrite *pw = (PendingWrite *)GC_MALLOC(sizeof *pw);

  pw->buf = buf;
  pw->len = len;
  pw->sent = 0;
  pw->callback = fn ? fn : NullChannelCallback;
  /* pw becomes the new tail; don't depend on the allocator zero-filling. */
  pw->next = NULL;

  if (c->writeQueue == NULL) {
    c->writeQueue = pw;
  }
  else {
    PendingWrite *tail = c->writeQueue;

    while (tail->next != NULL)
      tail = tail->next;
    tail->next = pw;
  }
}

/*
 * Arm channel c to read len bytes.  No I/O happens here: ssl_pull() does
 * the reading and then runs fn, or NullChannelCallback when fn is NULL.
 * Only one read may be outstanding at a time.
 */
static void
ssl_receive(Channel *c, size_t len, ChannelCompletionFn fn) 
{
  assert(c->rCallback == 0);

  c->rCallback = (fn != NULL) ? fn : NullChannelCallback;

  /* Every request starts with a brand-new, empty buffer. */
  c->read_buf = buffer_create(); /* new buffer for each read! */
  c->read_wanted = len;
  c->read_so_far = 0;
}

/*
 * Asynchronous close: log the request and mark the channel closed.
 *
 * This only sets c->closed.  The SSL shutdown, SSL_free() and socket
 * close are done by ssl_close().
 */
static void
ssl_aclose(Channel *c)
{
  /* GC_get_heap_size() yields a size_t; print it via unsigned long. */
  log_trace(TRC_COMM, 
            "Closing (async) SSL connection (connection id=%d, heap=%lu)\n",
            c->connection_id, (unsigned long) GC_get_heap_size());
  c->closed = TRUE;
}

/*
 * Synchronous close: drop any pending read callback, shut down and free
 * the SSL object, then close the socket and mark the channel closed.
 * ExConnLost raised during shutdown or close is ignored.
 *
 * c->ssl and c->sbio are reset to NULL after freeing.  A repeated close
 * therefore cannot double-free, and later misuse fails on a NULL pointer
 * instead of touching freed memory.
 */
static void
ssl_close(Channel *c)
{
  /* GC_get_heap_size() yields a size_t; print it via unsigned long. */
  log_trace(TRC_COMM,
            "Shutting down SSL connection (connection id=%d, heap=%lu)\n", 
            c->connection_id, (unsigned long) GC_get_heap_size());
  c->rCallback = NULL;

  if (c->ssl != NULL) {
    TRY { 
      SSL_shutdown(c->ssl);
    }
    CATCH(ExConnLost) {
      /* do nothing */
    }
    END_CATCH;

    /* SSL_set_bio() gave the BIO to the SSL object, so this frees both. */
    SSL_free(c->ssl);
    c->ssl = NULL;
    c->sbio = NULL;
  }

  TRY { 
    close(c->sockfd);
    c->closed = TRUE;
  }
  CATCH(ExConnLost) {
    /* do nothing */
  }
  END_CATCH;
}
