/*
 * Copyright (c) 2002, The EROS Group, LLC and Johns Hopkins
 * University. All rights reserved.
 * 
 * This software was developed to support the EROS secure operating
 * system project (http://www.eros-os.org). The latest version of
 * the OpenCM software can be found at http://www.opencm.org.
 * 
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 
 * 2. Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 * 
 * 3. Neither the name of the The EROS Group, LLC nor the name of
 *    Johns Hopkins University, nor the names of its contributors
 *    may be used to endorse or promote products derived from this
 *    software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <opencm.h>

/* Channel helpers defined outside this file. The callers in this file
   treat a negative return value from either one as failure. */
extern int chan_SetNonblocking(Channel *);
extern int chan_connect(Channel *, URI *);

/* TCP implementations of the Channel operations. chan_init_tcp()
   stores them in the send / flush / receive / pull / close / aclose
   slots of every channel it creates. Note that the OC_bool parameter
   called andWait here is named `blocking' in the definitions of
   tcp_flush() and tcp_pull(). */
static void tcp_send(Channel *c, const void *buf, size_t len, ChannelCompletionFn fn);
static void tcp_flush(Channel *c, OC_bool andWait);
static void tcp_receive(Channel *c, void *buf, size_t len, ChannelCompletionFn fn);
static void tcp_pull(Channel *c, OC_bool andWait);
static void tcp_close(Channel *c);
static void tcp_aclose(Channel *c);

/*
 * Obtain a fresh Channel from chan_create() and fill in its operation
 * slots with the TCP implementations from this file.
 *
 * fn  Optional read-completion callback. When non-NULL it is stored in
 *     the channel's rCallback; when NULL, rCallback keeps whatever
 *     value chan_create() gave it.
 *
 * Returns the initialised channel.
 */
static Channel *
chan_init_tcp(ChannelCompletionFn fn)
{
  Channel *chan = chan_create();

  /* Outbound operations. */
  chan->send = tcp_send;
  chan->flush = tcp_flush;

  /* Inbound operations. */
  chan->receive = tcp_receive;
  chan->pull = tcp_pull;

  /* Teardown operations. */
  chan->close = tcp_close;
  chan->aclose = tcp_aclose;

  if (fn != NULL) {
    chan->rCallback = fn;
  }

  return chan;
}

/*
 * Create a TCP channel (with no initial read callback) and hand it to
 * chan_connect() together with uri.
 *
 * Returns the channel when chan_connect() returns a non-negative
 * value, and NULL when it returns a negative one.
 */
Channel *
chan_connect_tcp(URI *uri)
{
  Channel *chan = chan_init_tcp(NULL);
  int rc = chan_connect(chan, uri);

  return (rc < 0) ? NULL : chan;
}

/*
 * Create a TCP channel whose socket is bound, listening, close-on-exec
 * and non-blocking.
 *
 * fn    Read-completion callback passed to chan_init_tcp() (may be NULL).
 * host  Address to bind to, or NULL to listen on any local address.
 *       With HAVE_GETADDRINFO this is resolved by getaddrinfo();
 *       without it only a dotted-quad IPv4 address is accepted.
 * port  TCP port number in host byte order.
 *
 * Returns the listening channel, or NULL on any failure. On failure
 * the socket (if one was opened) is closed and the getaddrinfo()
 * result (if any) is released before returning.
 */
Channel *
chan_alisten_tcp(ChannelCompletionFn fn, char *host, unsigned short port)
{
#define BACKLOG 10
#ifdef HAVE_GETADDRINFO
  struct addrinfo hints, *ai = NULL;
  char strport[NI_MAXSERV];
  int gaierr;
#else
  struct sockaddr_in servaddr;
#endif
  int val = 1;
  Channel *c;

#ifdef HAVE_GETADDRINFO
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = (host == NULL) ? AI_PASSIVE : 0;
  snprintf(strport, sizeof strport, "%u", (unsigned) port);

  if ((gaierr = getaddrinfo(host, strport, &hints, &ai)) != 0) {
    report(0, "bad addr or host: %s (%s)", (host == NULL) ? "<NULL>" : host,
           gai_strerror(gaierr));
    return NULL;
  }
#else
  /* The fallback path must build the bind address itself; previously
     servaddr was handed to bind() without ever being filled in. */
  memset(&servaddr, 0, sizeof(servaddr));
  servaddr.sin_family = AF_INET;
  servaddr.sin_port = htons(port);
  if (host == NULL) {
    servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
  }
  else {
    servaddr.sin_addr.s_addr = inet_addr(host);
    if (servaddr.sin_addr.s_addr == INADDR_NONE) {
      report(0, "bad addr or host: %s (%s)", host,
             "not a dotted-quad IPv4 address");
      return NULL;
    }
  }
#endif

  c = chan_init_tcp(fn);

#ifdef HAVE_GETADDRINFO
  c->sockfd = socket(ai->ai_family, SOCK_STREAM, 0);
#else
  c->sockfd = socket(AF_INET, SOCK_STREAM, 0);
#endif
  if (c->sockfd < 0)
    goto fail;

  /* Ensure that client connections get closed when we run
     subprocesses. */
  fcntl(c->sockfd, F_SETFD, FD_CLOEXEC | fcntl(c->sockfd, F_GETFD));

  setsockopt(c->sockfd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));

#ifdef HAVE_GETADDRINFO
  if (bind(c->sockfd, ai->ai_addr, ai->ai_addrlen) < 0)
    goto fail_close;

  /* The resolved address list is not needed once the socket is bound. */
  freeaddrinfo(ai);
  ai = NULL;
#else
  if (bind(c->sockfd, (struct sockaddr *) &servaddr, sizeof(servaddr)) < 0)
    goto fail_close;
#endif

  if (listen(c->sockfd, BACKLOG) < 0)
    goto fail_close;

  /* This will make the listen socket non-blocking: */
  if (chan_SetNonblocking(c) < 0)
    goto fail_close;

  return c;

fail_close:
  /* Do not leak the descriptor of a socket we could not finish
     setting up. */
  close(c->sockfd);

fail:
#ifdef HAVE_GETADDRINFO
  if (ai != NULL)
    freeaddrinfo(ai);
#endif
  return NULL;
}

/*
 * Accept one pending connection on the listening channel c and wrap
 * it in a new non-blocking, close-on-exec TCP channel.
 *
 * Returns the client channel, or NULL when no connection is pending
 * (EWOULDBLOCK / EAGAIN, not reported), when accept() fails for any
 * other reason (reported), or when the new socket cannot be made
 * non-blocking (reported; the accepted descriptor is closed).
 *
 * NOTE(review): the peer address is stored in a struct sockaddr_in and
 * printed with inet_ntoa(), so the logged address is only meaningful
 * for IPv4 peers -- confirm whether the listener can be AF_INET6.
 */
Channel *
chan_accept_tcp(Channel *c)
{
  int fd;
  struct sockaddr_in cliaddr;
  socklen_t clientlen = sizeof(cliaddr);  /* accept() takes a socklen_t * */
  Channel *client;

  memset(&cliaddr, 0, sizeof(cliaddr));
  fd = accept(c->sockfd, (struct sockaddr *) &cliaddr, &clientlen);
  if (fd < 0) {
    /* A non-blocking listener with nothing queued may report either
       of these two codes; neither is worth logging. */
    if (errno != EWOULDBLOCK && errno != EAGAIN) {
      report(0, "Server Accept failed from %s\n",
             inet_ntoa(cliaddr.sin_addr));
    }
    return NULL;
  }

  /* Ensure that client connections get closed when we run
     subprocesses. */
  fcntl(fd, F_SETFD, FD_CLOEXEC | fcntl(fd, F_GETFD));

  debug(2, "Processing request from %s\n",
        inet_ntoa(cliaddr.sin_addr));

  client = chan_init_tcp(NULL);
  client->sockfd = fd;
  if (chan_SetNonblocking(client) < 0) {
    report(0, "Server couldn't make non-blocking Channel for %s\n",
           inet_ntoa(cliaddr.sin_addr));
    /* The channel is being abandoned, so release its descriptor
       instead of leaking it. */
    close(fd);
    return NULL;
  }

  return client;
}

/*
 * Write out the channel's queued PendingWrite entries, head first.
 *
 * For each entry, write() is called until all of its bytes have been
 * sent; then the entry's callback is invoked and the entry is removed
 * from the head of the queue.
 *
 * blocking  When the socket would block: if true, wait via
 *           chan_blockForWriting() and retry; if false, return with
 *           the remaining data still queued.
 *
 * If the connection is lost (ExConnLost) the channel is marked closed
 * via aclose(), the queue is dropped and the function returns. Any
 * other write error does the same and then throws ExIoError.
 */
static void
tcp_flush(Channel *c, OC_bool blocking)
{
  /* Must be signed: write() returns -1 on error. With an unsigned type
     the error tests below could never fire. */
  ssize_t sent = 0;

  while (c->writeQueue) {
    PendingWrite *pw = c->writeQueue;

    while (pw->sent < pw->len) {
      TRY {
        sent = write(c->sockfd, pw->buf + pw->sent, pw->len - pw->sent);
      } /* end of the TRY block */
      CATCH(ExConnLost) {
        c->aclose(c);
        c->writeQueue = 0;
      }
      END_CATCH;

      if (c->closed)
        return;

      assert(sent != 0);

      if (sent < 0) {
        /* Capture errno before aclose() runs: it logs, and that may
           overwrite errno before the error message is formatted. */
        int saved_errno = errno;

        if (saved_errno == EAGAIN || saved_errno == EWOULDBLOCK) {
          if (!blocking)
            return;

          chan_blockForWriting(c, NULL);
          continue;
        }

        c->aclose(c);
        c->writeQueue = 0;
        THROW(ExIoError, format("I/O error on channel %d write: %s",
                                c->sockfd, strerror(saved_errno)));
        return;
      }

      /* Process what we got and get some more. */

      pw->sent += sent;
    }

    {
      /* The callback runs while pw is still at the head of the queue;
         the entry is unlinked only afterwards. */
      ChannelCompletionFn fn = pw->callback;
      pw->callback = NULL;
      fn(c);
      c->writeQueue = pw->next;
    }
  }
}

/*
 * Try to satisfy the channel's pending read request and, once it is
 * complete, invoke its read callback (after clearing rCallback).
 *
 * Does nothing when no read callback is registered.
 *
 * blocking  When no data is available: if true, wait via
 *           chan_blockForReading() and retry; if false, return and
 *           leave the request pending.
 *
 * End-of-file (read() returning 0) and a lost connection (ExConnLost)
 * mark the channel closed via aclose() and return without invoking
 * the callback. Any other read error marks the channel closed and
 * throws ExIoError.
 */
static void
tcp_pull(Channel *c, OC_bool blocking)
{
  /* Must be signed: read() returns -1 on error. With an unsigned type
     the error branch below could never be taken. */
  ssize_t nread = 0;

  if (!c->rCallback) return;

  /* This is a bit misleading. We are really checking for connection
     completion here. total_read is actually the *length* of the
     pending read request. Zero signals a connect (which should be
     cleaned up). */
  if (c->total_read == 0) {
    if (blocking) chan_blockForReading(c, NULL);
  }
  else {
    while (c->read_so_far < c->total_read) {
      TRY {
        nread = read(c->sockfd, c->read_buffer + c->read_so_far,
                     c->total_read - c->read_so_far);
      }
      CATCH(ExConnLost) {
        c->aclose(c);
      }
      END_CATCH;

      /* Same check tcp_flush() makes: once the channel is closed,
         nread may be left over from an earlier iteration and must not
         be accounted again. */
      if (c->closed)
        return;

      if (nread == 0) {
        /* Peer closed the connection. */
        c->aclose(c);
        return;
      }

      if (nread < 0) {
        /* Capture errno before aclose() runs: it logs, and that may
           overwrite errno before the error message is formatted. */
        int saved_errno = errno;

        if (saved_errno == EAGAIN || saved_errno == EWOULDBLOCK) {
          if (!blocking)
            return;

          chan_blockForReading(c, NULL);
          continue;
        }

        c->aclose(c);
        THROW(ExIoError, format("I/O error on channel %d read: %s",
                                c->sockfd, strerror(saved_errno)));
        return;
      }

      c->read_so_far += nread;
    }
  }

  if (!c->closed) {
    /* Clear the slot first so the callback may register a new read. */
    ChannelCompletionFn fn = c->rCallback;
    c->rCallback = NULL;
    fn(c);
  }
}

/**
  * Append a write request to the tail of this Channel's outbound
  * (write) queue. Nothing is written to the socket here.
  *
  * The buffer pointer is stored as-is, not copied. A NULL fn is
  * replaced by NullChannelCallback.
  *
  * NOTE(review): the new entry's `next' field is never assigned here;
  * this presumably relies on GC_MALLOC returning zeroed storage --
  * confirm.
  */
static void
tcp_send(Channel *c, const void *buf, size_t len, ChannelCompletionFn fn) 
{
  PendingWrite *entry = (PendingWrite *)GC_MALLOC(sizeof(PendingWrite));
  PendingWrite **link;

  entry->buf = buf;
  entry->len = len;
  entry->sent = 0;
  entry->callback = fn ? fn : NullChannelCallback;

  /* Walk to the first empty link (the queue head if the queue is
     empty, otherwise the last entry's next pointer) and hook in. */
  for (link = &c->writeQueue; *link; link = &(*link)->next)
    ;
  *link = entry;
} 

/*
 * Register a read request on the channel: len bytes are to be placed
 * in buf, and fn (or NullChannelCallback when fn is NULL) is recorded
 * as the read callback. No data is read here.
 *
 * Asserts that no read callback is already registered.
 */
static void
tcp_receive(Channel *c, void *buf, size_t len, ChannelCompletionFn fn) 
{
  assert(c->rCallback == 0);

  /* Describe the request... */
  c->read_buffer = buf;
  c->total_read = len;
  c->read_so_far = 0;

  /* ...and record who to notify. */
  c->rCallback = fn ? fn : NullChannelCallback;
} 

/*
 * Asynchronous close: log the event and mark the channel closed.
 * The socket descriptor itself is not touched here.
 */
static void
tcp_aclose(Channel *c)
{
  /* GC_get_heap_size() returns a size_t; cast it so the argument
     matches the %d conversion in the message. */
  report(1, "Closing (async) TCP connection (connection id=%d, heap=%d)\n",
         c->connection_id, (int) GC_get_heap_size());
  c->closed = TRUE;
}

/*
 * Shut the channel down: log the event, drop the pending read callback
 * and the outbound queue, and close the socket descriptor. A lost
 * connection signalled while closing (ExConnLost) is ignored.
 */
static void
tcp_close(Channel *c)
{
  /* GC_get_heap_size() returns a size_t; cast it so the argument
     matches the %d conversion in the message. */
  report(1, "Shutting down TCP connection (connection id=%d, heap=%d)\n", 
         c->connection_id, (int) GC_get_heap_size());
  c->rCallback = NULL;
  c->writeQueue = NULL;
  TRY { 
    close(c->sockfd);
  }
  CATCH(ExConnLost) {
    /* do nothing */
  }
  END_CATCH;
}
