/*
 * Copyright (c) 2002, The EROS Group, LLC and Johns Hopkins
 * University. All rights reserved.
 * 
 * This software was developed to support the EROS secure operating
 * system project (http://www.eros-os.org). The latest version of
 * the OpenCM software can be found at http://www.opencm.org.
 * 
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 
 * 2. Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 * 
 * 3. Neither the name of the The EROS Group, LLC nor the name of
 *    Johns Hopkins University, nor the names of its contributors
 *    may be used to endorse or promote products derived from this
 *    software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <opencm.h>
#include "opencmserver.h"
#include "../client/command.h"
#include <syslog.h>

/* Size of activeChannels: a channel whose socket descriptor is
 * >= MAX_CHANNELS is refused rather than tracked. */
#define MAX_CHANNELS 2000

/* Non-zero asks the main loop in opencm_server() to stop and run
 * svr_Cleanup().  volatile sig_atomic_t is the only object type that a
 * signal handler may portably write and the interrupted code read. */
static volatile sig_atomic_t quit;

/* Channels indexed by socket descriptor (NULL = slot unused).  This is
 * the table handed to chan_select() by the main loop. */
static Channel *activeChannels[MAX_CHANNELS];

/* Highest descriptor stored in activeChannels so far; bounds the scans
 * done by chan_select() and svr_Cleanup(). */
static int max_fd_so_far;

/* Here are the server-side callbacks, in order
 * of execution */
void svr_DoListen(Channel *listenChan);
void svr_DoClientSetup(Channel *c);
void svr_StartSession(Channel *c);
void svr_StartRequest(Channel *c);
void svr_GetRequest(Channel *c);
void svr_DoRequest(Channel *, Request *);

/* Misc server routines */
void svr_SetupSSL(void);
int  svr_OpenRepos(const char *path);
void svr_Cleanup(void);

/* One server manages one repository */
Repository *localRepos = NULL;

/* Decide whether PATH, the repository named by a connecting client,
 * is the repository this server runs.
 *
 * The intended test (same device and inode as localRepos' path) is
 * compiled out, so for now every PATH is accepted and TRUE is
 * returned. */
static OC_bool
check_repos(const char *path)
{
#if 0
  struct stat served, requested;

  stat(localRepos->uri->path, &served);
  stat(path, &requested);
  return (served.st_dev == requested.st_dev)
    && (served.st_ino == requested.st_ino);
#else
  (void) path;                  /* unused while the check is disabled */
  return TRUE;
#endif
}

/* Decide whether the user authenticated on session S may use this
 * repository, logging the reason for any refusal.
 *
 * Returns FALSE when the user is unknown (authUser == 0), has an empty
 * access mask (authAccess == 0) or carries the ACC_REVOKED bit;
 * TRUE otherwise.
 *
 * This is a pure predicate: it does not close the session's channel.
 * The caller closes a rejected client (svr_DoClientSetup does), so the
 * channel is closed exactly once instead of twice. */
static OC_bool
svr_HasReposAccess(Session *s)
{
  if (s->repos->authUser == 0) {
    report(1, "User is unknown to this Repository.\n");
    return FALSE;
  }

  if (s->repos->authAccess == 0) {
    report(1, "User doesn't have access to this Repository.\n");
    return FALSE;
  }

  if (s->repos->authAccess & ACC_REVOKED) {
    report(1, "User has been revoked.\n");
    return FALSE;
  }

  return TRUE;
}

/* Server manages one and only one repository.  This routine opens the
 * server's repository (stored in localRepos) and connects to it.
 *
 * PATH is the repository URI; it must start with the CM_REPOS_FILE
 * (file-system based) scheme.  If opt_AlternatePort is set it must be
 * a decimal port number in 1..65535; it then overrides the port
 * recorded in the repository URI.
 *
 * Returns 0 on success, -1 on failure. */
int
svr_OpenRepos(const char *path)
{
  OC_bool connected = FALSE;
  long port = 0;

  /* Validate the Repository before doing anything else: */
  if (!path) {
    report(0,"You must specify a repository.\n");
    return -1;
  }
  /* The server's repository is specified via a uri.  The scheme
   * of the uri is a meta-type and must correspond to any of the
   * file system based repositories we support.  Currently, we
   * support two:  the 'fs' and the 'rcsfs'.  This check just
   * ensures that the server's repository is specified correctly. */
  if (strncmp(path, CM_REPOS_FILE, strlen(CM_REPOS_FILE)) != 0) { 
    report(0,"You must specify a file system based repository.\n");
    return -1;
  }

  /* Validate the alternate port up front.  A bare strtol() turns
   * non-numeric text into 0, and a port of 0 makes the listener fall
   * back to the default port, so a typo would otherwise go unnoticed;
   * out-of-range values would be accepted as well. */
  if (opt_AlternatePort) {
    char *end;

    port = strtol(opt_AlternatePort, &end, 10);
    if (end == opt_AlternatePort || *end != '\0' ||
        port < 1 || port > 65535) {
      report(0, "Invalid port number: %s\n", opt_AlternatePort);
      return -1;
    }
  }

  /* Here's where the actual underlying repository type is determined */
  localRepos = repository_open(path);
  if (!localRepos) {
    return -1;
  }

  /* Any exception from the connect is treated as "not connected";
   * the caller reports the failure. */
  TRY {
    repos_Connect(localRepos, NULL);
    connected = TRUE;
  }
  DEFAULT(AnyException) {
  }
  END_CATCH;
  if (connected == FALSE)
    return -1;

  /* If port was specifically specified, set it accordingly: */
  if (opt_AlternatePort) {
    localRepos->uri->port = port;
  }

  return 0;
}

/**
 * Perform a best-effort cleanup and then exit.
 */
void
svr_Cleanup(void)
{
  unsigned u;

  report(1, "Closing...\n");

  /* Go through all the active channels, attempting to do an orderly
     shutdown on them. 

     Issue: shouldn't we attempt to give the channels a chance to
     drain? */

  for (u = 0; u <= max_fd_so_far; u++) 
    if (activeChannels[u])
      activeChannels[u]->close(activeChannels[u]);

  if (localRepos)
    repos_Disconnect(localRepos);

  do_exit(0);
}

/**
  * Signal handling function to quit gracefully.
  */
static void 
svr_End(int sigid) 
{
  ssl_destroy_ctx(ssl_ctx);
  svr_Cleanup();
}

void
svr_SetupSSL()
{
  const char *path = path_join(localRepos->uri->path, R_CONFIG);

  const char *keyfile  = xstrdup(path_join(path,R_KEY_FILE));
  const char *certfile = xstrdup(path_join(path,R_CERT_FILE));

  SSL_load_error_strings();
  SSL_library_init();

  ssl_ctx = ssl_init_ctx(keyfile, certfile);

#if 0
  if (dhfile) 
    load_dh_params(ssl_ctx, dhfile);
#endif

  generate_eph_rsa_key(ssl_ctx);
}

/**
 * Read callback: turn one complete client message into a Request and
 * dispatch it.
 *
 * On entry c->read_buffer holds the c->total_read bytes of the message
 * body (svr_StartRequest consumed the length prefix).  The bytes are
 * copied into a Buffer, decompressed when the Session's
 * compressionLevel is non-zero, and deserialized.  A message that
 * does not deserialize (sdr_read() yields NULL) closes the client.
 * Otherwise the request is handled by svr_DoRequest(), a collection is
 * forced, and the Channel is re-armed to read the next length prefix.
 */
void
svr_GetRequest(Channel *c)
{
  Buffer *wire = buffer_create();
  int level = ((Session *) c)->compressionLevel;
  SDR_stream *in;
  Request *req;

  if (c->closed)
    return;

  buffer_append(wire, c->read_buffer, c->total_read);
  if (level != 0)
    wire = buffer_decompress(wire);

  in = stream_fromBuffer(wire, SDR_WIRE);
  req = sdr_read(SERVERREQUEST, in);
  stream_close(in);

  /* Not a well-formed message: drop this client. */
  if (req == NULL) {
    c->aclose(c);
    return;
  }

  svr_DoRequest(c, req);
  GC_gcollect();

  /* Expect another request from the same client. */
  chan_aread(c, sizeof(reqlen_t), svr_StartRequest);
}


/* Serialize REPLY and queue it on channel C as a length-prefixed
 * message.  The prefix is a reqlen_t byte count in network byte order,
 * followed by the serialized bytes, which are queued one buffer chunk
 * at a time with chan_awrite().  Used only by the session-opening
 * code (svr_CheckRepos). */
static void
send_reply(Channel *c, Reply *reply)
{
  /* chan_awrite() keeps using the pointer it is given after this
   * function returns, so the length prefix lives in GC-allocated
   * memory rather than on the stack. */
  char *prefix = (char *) GC_MALLOC_ATOMIC(sizeof(reqlen_t));
  SDR_stream *out = stream_createBuffer(SDR_WIRE);
  Buffer *payload;
  reqlen_t netLen;
  ocmoff_t size;
  ocmoff_t done = 0;

  sdr_write(SERVERREPLY, out, reply);
  payload = stream_asBuffer(out);

  netLen = buffer_length(payload);
  netLen = htonl(netLen);
  memcpy(prefix, &netLen, sizeof(netLen));
  chan_awrite(c, prefix, sizeof(netLen), NULL);

  size = buffer_length(payload);
  while (done < size) {
    ocmoff_t left = size - done;
    BufferChunk piece = buffer_getChunk(payload, done, left);

    assert(piece.len <= left);
    chan_awrite(c, piece.ptr, piece.len, NULL);
    done += piece.len;
  }
}

static void
svr_CheckRepos(Channel *c)
{
  Request *request;
  Reply *reply;
  SDR_stream *strm;

  if (c->closed)
    return;

  /* c->read_buffer is a buffer containing a valid request to be
     deserialized */
  strm = stream_fromMemory(c->read_buffer, c->total_read, SDR_WIRE);
  request = sdr_read(SERVERREQUEST, strm);
  stream_close(strm);

  /* Each client session must begin with the STARTSESSION message 
   * which must have a repository path as its single argument.  Here
   * we simply compare the repository path that user specified with
   * the one the server is actually running.  They both must resolve
   * to the same physical directory in order to continue. */
  if (request->opcode == OP_StartSession) {
    if (vec_size(request->args) == 2) {
      WireString *arg = (WireString *)vec_fetch(request->args, 0);
      WireUnsigned *protocol = (WireUnsigned *)vec_fetch(request->args, 1);
      if (check_repos(arg->value)) {
	Session *client = (Session *)c;
	Repository *repos = client->repos;

	/* We need to return multiple values to client */
        ObVec *retval = obvec_create();

	/* Here's where we will determine which protocol version
	 * server can support.  Right now, just send back what client
	 * sent us */

	/* Attach repository ID */
	obvec_append(retval, wires_create(repos->repositoryID));

	/* Attach protocol version we will use */

	if (protocol->value < CM_LEAST_PROTO_VER)
	  obvec_append(retval, wireu_create(CM_PROTO_NOCOMMONVER));
	else
	  obvec_append(retval, wireu_create(CM_PROTO_VER));

	reply = reply_create(request->reqID, (Serializable *)retval);
	send_reply(c, reply);

	/* Set the Channel to wait for client's first request */
	chan_aread(c, sizeof(reqlen_t), svr_StartRequest);

	if (protocol->value < CM_LEAST_PROTO_VER)
	  c->aclose(c);

	return;

      } else {
	/* The comparison check failed: */
	Serializable *s = (Serializable *)
	  wireExcpt_create(ExBadValue, __FILE__, __LINE__,
			   "Mismatched repository name");

	reply = reply_create(request->reqID, s);
	send_reply(c, reply);
	c->aclose(c);
      }

    } else {
      Serializable *s = (Serializable *)
	wireExcpt_create(ExMalformed, __FILE__, __LINE__,
			 "Wrong number of arguments to OP_StartSession");

      reply = reply_create(request->reqID, s);

      send_reply(c, reply);
      c->aclose(c);
    }
  }
  else {
    Serializable *s = (Serializable *)
      wireExcpt_create(ExBadOpcode, __FILE__, __LINE__,
		       "Expected OP_StartSession");

    reply = reply_create(request->reqID, s);
    send_reply(c, reply);
    c->aclose(c);
  }
}

/* Read callback: C's read_buffer holds the reqlen_t length prefix (in
 * network byte order) of the session-opening message.  Arms the Channel
 * to read that many bytes and hand them to svr_CheckRepos().  Does
 * nothing if the Channel has already been closed. */
void
svr_StartSession(Channel *c)
{
  reqlen_t wireLen;
  reqlen_t bodyLen;

  if (!c->closed) {
    memcpy(&wireLen, c->read_buffer, sizeof(reqlen_t));
    bodyLen = ntohl(wireLen);
    chan_aread(c, bodyLen, svr_CheckRepos);
  }
}

/**
  * Read callback that begins receiving one client request.  The
  * protocol has the client send the length of its message first.  On
  * entry C's read_buffer holds that reqlen_t length in network byte
  * order.  The Channel is then armed to read exactly that many bytes,
  * which svr_GetRequest turns into a Request.  Does nothing if the
  * Channel has already been closed.
  */
void
svr_StartRequest(Channel *c)
{
  reqlen_t wireLen;
  reqlen_t bodyLen;

  if (!c->closed) {
    memcpy(&wireLen, c->read_buffer, sizeof(reqlen_t));
    bodyLen = ntohl(wireLen);
    chan_aread(c, bodyLen, svr_GetRequest);
  }
}

/**
  * Read callback of a listening Channel: accept one incoming client.
  *
  * An accepted client is given a maximum idle time (CLIENT_TIMEOUT if
  * defined, otherwise 300 seconds).  The accepted Channel is then
  * copied into a new Session bound to a duplicate of localRepos,
  * registered in activeChannels under its socket descriptor, and
  * started on its handshake with svr_DoClientSetup as the completion
  * callback.  A client whose descriptor does not fit in activeChannels
  * is closed.  In all cases the listening Channel's read callback is
  * re-armed.
  */
void
svr_DoListen(Channel *listenChan)
{
  Channel *incoming = chan_accept_ssl(listenChan);

  if (incoming == NULL) {
    report (0, "Client connect failed.\n");
  }
  else if (incoming->sockfd >= MAX_CHANNELS) {
    report(0, "Client connect failed: exceeded max channels.\n");
    incoming->aclose(incoming);
  }
  else {
    Session *sess;

    /* Bound how long an idle client may hold a Channel.  This is set
     * on the accepted Channel before it is copied into the Session. */
#ifdef CLIENT_TIMEOUT
    chan_set_max_idle_time(incoming, CLIENT_TIMEOUT); 
#else
    chan_set_max_idle_time(incoming, 300);  /* 5 min */
#endif

    /* The initial session can reach the local repository well enough
     * to do GetUser, but nothing else until the client is set up. */
    sess = session_create(repository_dup(localRepos));
    memcpy(&(sess->channel), incoming, sizeof(Channel));

    activeChannels[sess->channel.sockfd] = &(sess->channel);
    chan_ashake(&(sess->channel), svr_DoClientSetup);

    if (incoming->sockfd > max_fd_so_far)
      max_fd_so_far = incoming->sockfd;
  }

  /* need to restore the read callback for the listening Channel: */
  listenChan->rCallback = svr_DoListen;
}

/* Callback for completing the client channel setup after the
 * network protocol handshake has completed.
 *
 * The steps are:
 *  - connect the session's repository using the peer's certificate;
 *  - wrap the repository with authrepository_wrap();
 *  - reject the client early if svr_HasReposAccess() denies access.
 * On success the client is logged and the Channel is armed to read the
 * length prefix of its StartSession message.  Any failure closes the
 * Channel. */
void
svr_DoClientSetup(Channel *c)
{
  OC_bool connected = FALSE;
  Session *s = (Session *)c;

  /* Ok, now that the protocol handshake is finished, we can retrieve
   * the user info and perform proper per-user authentication */
  TRY {
    repos_Connect(s->repos, c->peerCert);
    connected = TRUE;
  }
  DEFAULT(AnyException) {
    c->aclose(c);
  }
  END_CATCH;
  if (connected == FALSE)
    return;

  s->repos = authrepository_wrap(s->repos, c->peerCert);

  /* Eliminate invalid clients as soon as possible: the user must be
   * known, have some access, and not be revoked. */
  if (!svr_HasReposAccess(s)) {
    report(1, "Rejected unauthorized connection\n");
    c->aclose(c);
    return;
  }

  /* GC_get_heap_size() returns a size_t.  Passing it for %d is
   * undefined behaviour wherever size_t is wider than int (e.g. LP64),
   * so widen it explicitly and print with %lu. */
  report(1, "%s: Accepted connect from %s (connection id=%d, heap=%lu)\n",
	 os_GetISOTime(),
	 pubkey_GetEmail(c->peerCert), c->connection_id,
	 (unsigned long) GC_get_heap_size());

  chan_aread(c, sizeof(reqlen_t), svr_StartSession);
}

/****************************************************
* M A I N   processing loop.
*
* This server is based loosely on the classic
* non-blocking TCP socket server.  However, we use
* a Channel object, as opposed to using TCP sockets
* directly.  This allows us to use any lower level
* network protocol that we wish.  The main processing
* loop of this server uses a variant of the 'select'
* call to determine which Channel objects are available
* for reading and/or writing.  Buried in the 'select'
* call is the logic for kicking off asynchronous
* reads and/or writes for each available Channel.
*
****************************************************/
void 
opencm_server(WorkSpace *ws, int argc, char *argv[]) 
{
  struct sigaction quit_action, old_action;
  int u;
  Channel *listenChan;
  struct timeval tv;

  /* Become a daemon unless debugging has been requested: */
  if (opt_Foreground)
    report(0, "OpenCM running in foreground\n");
  else
    os_daemonize();

  /* Adjust the error reporting so that the server doesn't barf when
     an error needs to be logged. */
  server_mode = 1;

#ifndef LOG_PERROR
  #define LOG_PERROR 0
#endif
  openlog("opencm", LOG_CONS|LOG_PERROR|LOG_PID, LOG_DAEMON);

  /* Default server port number was established by main() */

  /* Open this server's repository */
  if (svr_OpenRepos(argv[0]) < 0)
    THROW(ExNoConnect, "Could not connect to repository");

  /* Now do SSL setup: */
  svr_SetupSSL();

  /* Set up signal handlers for shutting down server */
  quit_action.sa_handler = svr_End;
  sigemptyset(&quit_action.sa_mask);
  quit_action.sa_flags = 0;
  sigaction(SIGINT, &quit_action, &old_action);
  sigaction(SIGHUP, &quit_action, &old_action);
	
  quit = 0;
	
  /* This var is used to keep track of all our client connections
     (as far as doing partial reads from their Channels) */
  for (u = 0; u < MAX_CHANNELS; u++) 
    activeChannels[u] = NULL;

  /* Create listening Channels here.  These Channels
   * are strictly for listening for incoming client requests
   * on a specified port */
  listenChan = 
    chan_alisten_ssl(svr_DoListen, "0.0.0.0",
	(localRepos->uri->port ? localRepos->uri->port : opencmport));

  if (listenChan == NULL)
    THROW(ExNoConnect, "Error setting up listening channel");

  if (listenChan->sockfd < MAX_CHANNELS) {
    activeChannels[listenChan->sockfd] = listenChan;
    if (listenChan->sockfd > max_fd_so_far)
      max_fd_so_far = listenChan->sockfd;
  }
  else
    THROW(ExNoConnect, "Somehow you exceeded the max Channels allowed");

  listenChan =
    chan_alisten_ssl(svr_DoListen, "::",
	(localRepos->uri->port ? localRepos->uri->port : opencmport));

  if (listenChan) {
    if (listenChan->sockfd < MAX_CHANNELS) {
      activeChannels[listenChan->sockfd] = listenChan;
      if (listenChan->sockfd > max_fd_so_far)
	max_fd_so_far = listenChan->sockfd;
    }
    else
      THROW(ExNoConnect, "Somehow you exceeded the max Channels allowed");
  }
  else {
    report(0, "IPv6 support not configured in this kernel.\n"
	   "    Listening for IPv4 connections only.\n");
  }

  /* Main processing loop */
  while(!quit) {
    TRY {

      tv.tv_sec  = 0;
      tv.tv_usec = 250000;
      (void) chan_select(activeChannels, max_fd_so_far, &tv);
      GC_gcollect();
    }
    DEFAULT(ex) {
    }
    END_CATCH;
  }

  svr_Cleanup();
}
