/*
 * Copyright (c) 2002, The EROS Group, LLC and Johns Hopkins
 * University. All rights reserved.
 * 
 * This software was developed to support the EROS secure operating
 * system project (http://www.eros-os.org). The latest version of
 * the OpenCM software can be found at http://www.opencm.org.
 * 
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 
 * 2. Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 * 
 * 3. Neither the name of the The EROS Group, LLC nor the name of
 *    Johns Hopkins University, nor the names of its contributors
 *    may be used to endorse or promote products derived from this
 *    software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <opencm.h>

#include <errno.h>
#include <unistd.h>

/* Source of Channel connection ids: chan_create() post-increments it, so
 * ids start at 0.  Static storage is zero-initialized. */
static unsigned int global_connection_counter;

/* Forward declarations for functions defined later in this file. */
int chan_SetNonblocking(Channel *chan);
int chan_connect(Channel *chan, URI *target);
  
/* A Channel callback that deliberately does nothing.
 * Useful wherever a callback taking a Channel is required but no action is
 * wanted (presumably passed as a ChannelCompletionFn -- confirm at callers). */
void
NullChannelCallback(Channel *chan)
{
  (void) chan;  /* intentionally unused */
}

/* Send LEN bytes from BUF over the Channel, blocking until they have been
 * flushed (or an error occurs).
 *
 * c       = the Channel over which to send bytes
 * buf     = pointer to the array of bytes you wish to send
 * len     = total number of bytes of interest in buf
 * timeout = currently unused; the flush is always a blocking one
 *
 * Returns nothing.  Failure is reported by exception: if ExConnLost escapes
 * c->send()/c->flush(), the Channel is shut down with c->aclose() and
 * ExConnLost is rethrown to the caller.
 */
void
chan_write(Channel *c, const void *buf, size_t len, struct timeval *timeout)
{
  TRY {
    c->send(c, buf, len, NULL /* no callback */);

    c->flush(c, TRUE);	/* flush any pending output, possibly triggering callback */
  }
  CATCH(ExConnLost) {
    /* flush() is expected to deal with a lost connection itself; reaching
     * here means it did not, so log it, shut the Channel down and let the
     * caller see the loss. */
    xprintf("Caught ExConnLost in chan_write, which should have "
	    "been handled by chan_flush()\n");
    c->aclose(c);
    RETHROW(ExConnLost);
  }
  END_CATCH;
}

/* Read LEN bytes from the Channel into BUF, blocking until the request
 * completes (or an error occurs).
 *
 * c       = the Channel to read from
 * buf     = pointer to memory where the Channel should place the bytes
 * len     = total number of bytes to read
 * timeout = currently unused; the pull is always a blocking one
 *
 * Returns nothing.  Failure is reported by exception: if the Channel turns
 * out to be closed before the request can be queued, or if ExConnLost is
 * raised while pulling the data, the Channel is closed with c->close() and
 * ExConnLost is thrown to the caller.
 */
void
chan_read(Channel *c, void *buf, size_t len, struct timeval *timeout)
{
  /* Finish any pending read, possibly triggering its callback. */
  c->pull(c, TRUE);

  if (c->closed) {
    /* Returning silently here would leave BUF unfilled while the caller
     * assumes LEN bytes arrived; report the loss the same way the pull
     * failure path below does. */
    c->close(c);
    THROW(ExConnLost, "Channel is closed");
  }

  /* Queue the request: */
  c->receive(c, buf, len, NULL /* no callback */);

  TRY {
    c->pull(c, TRUE);
  }
  CATCH(ExConnLost) {
    xprintf("Caught ExConnLost in chan_read()!\n");
    c->close(c);
    RETHROW(ExConnLost);
  }
  END_CATCH;
}

/* Perform the Channel's protocol handshake synchronously.
 *
 * A no-op once the handshake has completed.  Otherwise calls the Channel's
 * shake hook in blocking mode and, if that returns normally, records the
 * handshake as done.
 */
void
chan_shake(Channel *c)
{
  if (!c->handshake_done) {
    c->shake(c, TRUE);
    c->handshake_done = TRUE;
  }
}

/* Arrange for the Channel's handshake to be driven asynchronously.
 *
 * This only records state: FN is stored as the handshake callback and the
 * Channel is marked as not yet handshaken.  chan_select() then calls
 * c->shake(c, FALSE) on every pass while handshake_done is false
 * (presumably the shake hook sets it on completion -- confirm).
 */
void
chan_ashake(Channel *c, ChannelCompletionFn fn)
{
  c->hCallback = fn;
  c->handshake_done = FALSE;
}

/* Queue LEN bytes from BUF for asynchronous transmission over the Channel.
 *
 * FN is handed to c->send() as the completion callback.  This function does
 * NOT flush: the data goes out when the Channel is later flushed -- e.g. by
 * chan_select(), which flushes Channels with a non-empty writeQueue once
 * their socket is writable (presumably send() only queues; confirm).
 */
void
chan_awrite(Channel *c, const void *buf, size_t len, ChannelCompletionFn fn)
{
  c->send(c, buf, len, fn);
}

/* Minimum size chan_aread() gives the Channel's read buffer when it has to
 * grow it; larger requests grow it to exactly the requested length. */
#define PLAUSIBLE_REQUEST_BUFFER_SZ  1024

/* Start an asynchronous read of LEN bytes into the Channel's own read buffer
 * (c->read_buffer), first growing that buffer if it is smaller than LEN.
 *
 * FN is passed to c->receive() as the completion callback.  Only one
 * asynchronous read may be outstanding per Channel.
 */
void
chan_aread(Channel *c, size_t len, ChannelCompletionFn fn)
{
  /* Channel must not have any pending INPUT: a non-NULL rCallback marks a
   * read still in progress (chan_select() watches the socket for
   * readability exactly when rCallback is set). */
  assert(c->rCallback == NULL);

  if (c->rbuf_limit < len) {
    c->rbuf_limit = max(len, PLAUSIBLE_REQUEST_BUFFER_SZ);
    /* The buffer only ever holds raw bytes, so a first allocation uses the
     * atomic (pointer-free) GC allocator; growth keeps the existing block. */
    if (c->read_buffer)
      c->read_buffer = (void *)GC_realloc(c->read_buffer, c->rbuf_limit);
    else
      c->read_buffer = (void *)GC_MALLOC_ATOMIC(c->rbuf_limit);
  }

  /* Set up the asynchronous read */
  c->receive(c, c->read_buffer, len, fn);
}

/* Block until the Channel's socket is readable (or TV expires).
 *
 * c  = the Channel to wait on
 * tv = select() timeout; NULL waits indefinitely
 *
 * Returns immediately if the Channel already has a pending read.
 * Throws ExConnLost if the Channel is closed, and ExIoError if select()
 * times out or fails.
 */
void
chan_blockForReading(Channel *c, struct timeval *tv)
{
  int ready = 0;
  fd_set rset, wset;

  if (c->closed)
    THROW(ExConnLost, "Channel is closed");

  /* If the intermediate buffer is ready, bypass the select call and just
   * return: */
  if (c->pendingRead)
    return;

  FD_ZERO(&rset);
  FD_ZERO(&wset);
  FD_SET(c->sockfd, &rset);

  TRY {
    ready = select(c->sockfd+1, &rset, &wset, NULL, tv);
  }
  CATCH(ExConnLost) {
    report(1, "Caught ExConnLost in chan_blockForReading(): connection id=%d\n",
           c->connection_id);
  }
  END_CATCH;

  /* select() returns 0 on timeout without setting errno, so reporting
   * strerror(errno) there would print a stale, unrelated error. */
  if (ready == 0)
    THROW(ExIoError, "Timed out in select() for reading");

  if (ready < 0)
    THROW(ExIoError, 
	  format("I/O error in select() for reading: %s", strerror(errno)));
}

/* Block until the Channel's socket is writable (or TV expires).
 *
 * c  = the Channel to wait on
 * tv = select() timeout; NULL waits indefinitely
 *
 * Throws ExConnLost if the Channel is closed, and ExIoError if select()
 * times out or fails.
 */
void
chan_blockForWriting(Channel *c, struct timeval *tv)
{
  int ready = 0;
  fd_set rset, wset;
  
  if (c->closed)
    THROW(ExConnLost, "Channel is closed");

  FD_ZERO(&rset);
  FD_ZERO(&wset);
  FD_SET(c->sockfd, &wset);

  TRY {
    ready = select(c->sockfd+1, &rset, &wset, NULL, tv);
  }
  CATCH(ExConnLost) {
    report(1, "Caught ExConnLost in chan_blockForWriting(): connection id=%d\n",
           c->connection_id);
    assert(FALSE);
  }
  END_CATCH;

  /* select() returns 0 on timeout without setting errno, so reporting
   * strerror(errno) there would print a stale, unrelated error. */
  if (ready == 0)
    THROW(ExIoError, "Timed out in select() for writing");

  if (ready < 0)
    THROW(ExIoError, 
	  format("I/O error in select() for writing: %s", strerror(errno)));
}

/* Service every Channel in CHANSET once, blocking in select() as needed.
 *
 * chanset = array of Channel pointers indexed 0..maxfd; NULL entries are
 *           skipped.  Channels found closed are closed via c->close() and
 *           their slot is set to NULL.
 * maxfd   = last chanset index to examine; it is also passed to select() as
 *           the highest descriptor, so chanset is presumably indexed by
 *           sockfd (confirm against callers).
 * tv      = select() timeout; NULL blocks until some descriptor is ready.
 *
 * Before blocking: idle-timed-out Channels are aclosed, unfinished
 * handshakes are advanced, closed Channels are removed, already-pending
 * input/output is pulled/flushed, and the descriptor sets are built.
 * After select(): every Channel whose descriptor became ready is
 * flushed/pulled in non-blocking mode.
 *
 * Returns select()'s result: the number of ready descriptors, or 0 on
 * timeout or when select() was interrupted by a signal.  Any other
 * select() failure is reported and trips an assertion.
 */
int
chan_select(Channel **chanset, int maxfd, struct timeval *tv)
{
  int ready = 0;
  int u = 0;
  time_t now = time(NULL);

  fd_set readset, writeset, exceptset;
  FD_ZERO(&readset);
  FD_ZERO(&writeset);
  FD_ZERO(&exceptset);

  /* For each Channel, if we already know we have
   * pending input or output we can service it now.
   * Otherwise, record which descriptors select()
   * should watch for it.
   */
  TRY {
    for (u = 0; u <= maxfd; u++) {
      if (chanset[u] == NULL)
	continue;

      if (chan_idle_too_long(chanset[u])) {
	if (opt_TraceProtocol)
	  xprintf("<-- Closing %d: idle timeout -->\n", chanset[u]->connection_id);
	chanset[u]->aclose(chanset[u]);
	/* The channel will be killed by the closed check just below */
      }

      /* Try to complete any protocol handshaking */
      if (!chanset[u]->handshake_done) 
	chanset[u]->shake(chanset[u], FALSE);

      /* Remove and close any Channels that were aclosed: */
      if (chanset[u]->closed) {
	report(1, "Connection dropped (connection id=%d, heap=%d)\n", 
	       chanset[u]->connection_id, GC_get_heap_size());
	chanset[u]->close(chanset[u]);
	chanset[u] = NULL;
	continue;
      }

      /* OK, if we're doing SSL or any similar protocol, we
       * need to flush or pull any pending Channels before blocking
       * on the select call below.  (That way we don't
       * arbitrarily prioritize the blocking select when we 
       * already have data we can process.) */
      if (chanset[u]->pendingRead) {
	chanset[u]->last_activity = now;
	chanset[u]->pull(chanset[u], FALSE);
      }

      if (chanset[u]->pendingWrite) {
	chanset[u]->last_activity = now;
	chanset[u]->flush(chanset[u], FALSE);
      }

      if (chanset[u]->rCallback)
	FD_SET(chanset[u]->sockfd, &readset);

      if (chanset[u]->writeQueue)
	FD_SET(chanset[u]->sockfd, &writeset);

      /* FIX: If there is an exception, don't we always want to know? */
      FD_SET(chanset[u]->sockfd, &exceptset);
    }
  }
  CATCH(ExConnLost) {
    xprintf("Caught ExConnLost in chan_select pre-check phase\n");
  }
  END_CATCH;

  /* Only do low-level select call as last resort: */
  TRY {
    ready = select(maxfd+1, &readset, &writeset, &exceptset, tv);
  }
  CATCH(ExConnLost) {
    xprintf("Caught ExConnLost in chan_select select call\n");
  }
  END_CATCH;

  if (ready == -1) {
    /* A signal interrupting select() is not a failure: nothing is ready,
     * and the descriptor sets must not be trusted, so skip dispatch and
     * let the caller come around again. */
    if (errno == EINTR)
      return 0;

    report(0, "Came off select with -1 and %s\n", strerror(errno));
    assert(0);
  }

  /* select() may have blocked for a long time.  Stamp activity with the
   * time descriptors actually became ready; using the pre-select time
   * could make a Channel that just received data look idle and get it
   * closed on the next pass. */
  now = time(NULL);

  /* For each channel that has something to do, do it: */
  for (u = 0; u <= maxfd; u++) {
    if (chanset[u] == NULL)
      continue;

    if (FD_ISSET(chanset[u]->sockfd, &writeset)) {
      chanset[u]->last_activity = now;
      chanset[u]->flush(chanset[u], FALSE);
    }

    if (FD_ISSET(chanset[u]->sockfd, &readset)) {
      chanset[u]->last_activity = now;
      chanset[u]->pull(chanset[u], FALSE);
    }

    if (FD_ISSET(chanset[u]->sockfd, &exceptset)) {
      /* An exceptional condition is answered by just flushing; this is
       * probably not the right response. */
      /* FIX:  chanset[u]->flags |= CFL_WANTCLOSE; */
      chanset[u]->flush(chanset[u], FALSE);
    }
  }

  return ready;
}

Channel *
chan_create(void)
{
  Channel *c = (Channel *)GC_MALLOC(sizeof(Channel));

  c->pendingRead = FALSE;
  c->pendingWrite = FALSE;

  c->rCallback = NULL;

  c->send = NULL;
  c->receive = NULL;
  c->flush = NULL;
  c->pull = NULL;
  c->close = NULL;
  c->aclose = NULL;
  c->closed = FALSE;

  c->read_buffer = NULL;
  c->read_so_far = 0;
  c->total_read = 0;
  c->rbuf_limit = 0;

  c->writeQueue = NULL;

  c->uri = NULL;

  c->ctx = 0;
  c->ssl = 0;
  c->sbio = 0;
  c->peerCert = 0;
  c->xinfo = 0;

  c->check_idleness = FALSE;
  c->max_idle_time = 0;
  c->last_activity = 0;
  c->connection_id = global_connection_counter++;

  return c;
}

/* Turn on idle-timeout checking for the Channel.
 *
 * From now on chan_idle_too_long() reports the Channel as idle once more
 * than SECONDS have passed since its last recorded activity.  The activity
 * clock is restarted at the time of this call.
 */
void
chan_set_max_idle_time(Channel *c, time_t seconds)
{
  c->check_idleness = TRUE;
  c->max_idle_time = seconds;
  c->last_activity = time(NULL);

  /* time() yields (time_t)-1 when the clock cannot be read. */
  if (c->last_activity < 0)
    report(1, "Error setting max time: this Channel will close prematurely.\n");
}

/* Report whether the Channel has exceeded its idle limit.
 *
 * Always false unless chan_set_max_idle_time() enabled idle checking;
 * otherwise true when more than max_idle_time seconds have elapsed since
 * last_activity.
 */
OC_bool
chan_idle_too_long(Channel *c)
{
  if (!c->check_idleness)
    return FALSE;

  return (time(NULL) - c->last_activity) > c->max_idle_time;
}

/* Put the Channel's socket into non-blocking mode, preserving its other
 * file status flags.
 *
 * Returns 0 on success, -1 if either fcntl() call fails.
 */
int
chan_SetNonblocking(Channel *c)
{
  int result = -1;
  int flags;
      
  TRY {
    flags = fcntl(c->sockfd, F_GETFL, 0);
    /* F_GETFL returns -1 on failure; OR-ing O_NONBLOCK into that and
     * handing it to F_SETFL would set every flag bit, so only proceed
     * when the current flags were actually read. */
    if (flags >= 0)
      result = fcntl(c->sockfd, F_SETFL, flags | O_NONBLOCK);
  }
  CATCH(ExConnLost) {
    xprintf("Caught ExConnLost in chan_SetNonblocking call\n");
  }
  END_CATCH;

  return (result < 0) ? -1 : 0;
}

#ifdef HAVE_GETADDRINFO
/* Resolve uri->netloc and open a TCP connection to it (getaddrinfo variant).
 *
 * The port is uri->port, or opencmport when uri->port is 0.  Each resolved
 * address is tried in order until one connects.  On success the connected
 * socket is stored in c->sockfd and made non-blocking, c->uri is set to URI,
 * and 0 is returned.  Returns -1 if C is NULL, resolution fails, no address
 * accepts the connection, or the socket cannot be made non-blocking (in
 * which case the socket is closed).
 */
int
chan_connect(Channel *c, URI *uri)
{
  struct addrinfo hints, *res, *res0;
  /* Room for any decimal int plus NUL.  (A 5-byte buffer silently
   * truncated every port >= 10000, e.g. "65535" -> "6553".) */
  char port[16];
  int fd, retry, error;
  int saved_errno = 0;

  /* NI_WITHSCOPEID was never incorporated into the IPv6
   * standard. Unfortunately, early versions of OpenBSD rely on it
   * being set, so we need to use it if the macro is present. */
#ifdef NI_WITHSCOPEID
  const int niflags = NI_NUMERICHOST | NI_WITHSCOPEID;
#else
  const int niflags = NI_NUMERICHOST;
#endif

  if (!c)
    return -1;

  memset(&hints, 0, sizeof(hints));
  hints.ai_family = PF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_CANONNAME;

  snprintf(port, sizeof(port), "%d", (uri->port == 0) ? opencmport : uri->port);

  error = getaddrinfo(uri->netloc, port, &hints, &res0);

  if (error) {
    if (error == EAI_SERVICE)
      report(0, "%s: bad port", port);
    else
      report(0, "(%d)%s:%s %s", error, uri->netloc, port, gai_strerror(error));

    if (h_errno)
      herror(uri->netloc);

    return -1;
  }

  fd = -1;
  retry = 0;

  for (res = res0; res != NULL; res = res->ai_next) {
    char hbuf[NI_MAXHOST];

    fd = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
    if (fd < 0) {
      saved_errno = errno;
      continue;
    }

    if (connect(fd, res->ai_addr, res->ai_addrlen) == 0)
      break;

    /* Capture the connect() error before getnameinfo()/close() can
     * overwrite errno. */
    saved_errno = errno;

    if (getnameinfo(res->ai_addr, res->ai_addrlen, hbuf, sizeof(hbuf), NULL,
                    0, niflags) != 0) {
      strcpy(hbuf, "(invalid)");
    }

    report(0, "Connect to %s failed: %s\n", hbuf, strerror(saved_errno));

    close(fd);
    fd = -1;
    retry++;
  }

  freeaddrinfo(res0);

  if (fd < 0) {
    report(0, "Connect failed after %d retries: %s\n", retry,
           strerror(saved_errno));
    return -1;
  }

  c->sockfd = fd;

  if (chan_SetNonblocking(c) < 0) {
    report(0, "Socket SetNonblocking Error\n");
    close(fd);	/* don't leak the connected socket */
    return -1;
  }

  c->uri = uri;

  return 0;
}
#else
/* Resolve uri->netloc and open a TCP connection to it (gethostbyname
 * variant, IPv4 only).
 *
 * The port is uri->port, or opencmport when uri->port is 0.  Each address
 * returned by gethostbyname() is tried in order until one connects.  On
 * success the connected socket is stored in c->sockfd and made
 * non-blocking, c->uri is set to URI, and 0 is returned.  Returns -1 if C is
 * NULL, the lookup fails, no address accepts the connection, or the socket
 * cannot be made non-blocking; no socket is left open on failure.
 */
int 
chan_connect(Channel *c, URI *uri)
{
  struct sockaddr_in cliaddr;
  struct hostent *netlocaddr;
  struct in_addr **hostptr;
  int fd = -1;

  if (!c)
    return -1;

  if ((netlocaddr = gethostbyname(uri->netloc)) == NULL) { 
    report(0, "Hostname lookup for %s failed", uri->netloc);
    return -1;
  }

  hostptr = (struct in_addr **) netlocaddr->h_addr_list;

  for ( ; *hostptr != NULL; hostptr++) {
    memset(&cliaddr, 0, sizeof(cliaddr));
    cliaddr.sin_family = AF_INET;
    cliaddr.sin_port = htons((uri->port == 0) ? opencmport : uri->port);
    memcpy(&cliaddr.sin_addr, *hostptr, sizeof(struct in_addr));

    report(1, "Connecting to %s: %s, port %d\n", 
	   uri->netloc, inet_ntoa(cliaddr.sin_addr), ntohs(cliaddr.sin_port)); 

    /* POSIX leaves a socket's state unspecified after a failed connect(),
     * so every attempt gets a fresh socket rather than reusing one. */
    fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0)
      continue;

    if (connect(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) == 0)
      break;

    close(fd);
    fd = -1;
  }

  if (fd < 0)
    return -1;

  c->sockfd = fd;

  if (chan_SetNonblocking(c) < 0) {
    report(0, "Socket Connection Error\n");
    close(fd);	/* don't leak the connected socket */
    return -1;
  }

  c->uri = uri;

  return 0;
}
#endif

