/*
 * Copyright (c) 2002, The EROS Group, LLC and Johns Hopkins
 * University. All rights reserved.
 * 
 * This software was developed to support the EROS secure operating
 * system project (http://www.eros-os.org). The latest version of
 * the OpenCM software can be found at http://www.opencm.org.
 * 
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 
 * 2. Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 * 
 * 3. Neither the name of the The EROS Group, LLC nor the name of
 *    Johns Hopkins University, nor the names of its contributors
 *    may be used to endorse or promote products derived from this
 *    software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <opencm.h>
#include "opencmserver.h"
#include "../../gc6.3/include/gc_backptr.h"

/* Forward declarations for functions with external linkage.
 * svr_DoRequest() is defined later in this file; svr_OpenRepos() is
 * declared here but not defined in this file. */
void svr_DoRequest(Channel *c, Request *req);
int svr_OpenRepos(const char *);

/* Store-repository URI string.  It starts out NULL and is not read or
 * written anywhere else in this file. */
char *storeReposURI = NULL;

/* Repository object defined in another translation unit.  It is
 * declared here but not referenced elsewhere in this file. */
extern Repository *localRepos;

/* Per-opcode request metadata.
 *
 * The OPCODE() macro below is expanded once for every entry in
 * ../common/OPCODES.def.  Each expansion produces one OpcodeType
 * initializer containing:
 *   - the OP_ constant;
 *   - the opcode's name as a string;
 *   - the number of incoming arguments;
 *   - up to five incoming argument types.
 * svr_GetArgs() uses this table to decode and type each request
 * argument.
 *
 * The initializers are positional.  The lookup opcodes[req->opcode]
 * in svr_GetArgs() is therefore only correct if the entries in
 * OPCODES.def appear in the same order as the OP_ enumeration values.
 * NOTE(review): confirm that the OP_ enum is generated from the same
 * .def file, in the same order.
 *
 * FIX: For completeness, the return arguments and their types should
 * also be encoded here.  That would make it possible to replace the
 * switch/case statement in svr_DoRequest() with table-driven dispatch.
 */

struct OpcodeType opcodes[OP_total] = {
#define OPCODE(nm,args,a1,a2,a3,a4,a5) \
  { OP_##nm, #nm, args, {a1,a2,a3,a4,a5} },
#include "../common/OPCODES.def"
};

/* Decoded arguments of a single incoming request.
 *
 * svr_GetArgs() zeroes the whole structure and then stores each
 * request argument in the slot that matches its declared type.  Slots
 * a request does not use remain NULL.  Repeated WireString and
 * WireUnsigned arguments are stored in the numbered slots, in the
 * order they appear. */
typedef struct Arguments {
  User         *user;
  PubKey       *pubkey;
  Mutable      *mut;
  Mutable      *dmut;	/* never filled in by svr_GetArgs() in this file */
  WireString   *string, *string2, *string3;
  WireUnsigned *uns, *uns2, *uns3;
  WireDelta    *wd;
  Serializable *ser;
  StrVec       *names;
} Arguments;

#ifdef INFOGRAMS
/* Send an information-only message to the client while the server is
 * still processing that client's request.
 *
 * The message is framed as one buffer: a reqlen_t byte count in network
 * byte order, followed immediately by the serialized reply bytes.  The
 * buffer is queued with a single chan_awrite() call.  It is heap
 * (GC) allocated because chan_awrite() may still use it after this
 * function returns.
 *
 * If the buffer cannot be allocated, the message is silently dropped.
 * It is informational only, so losing it does not affect the request.
 *
 * NOTE(review): this code is compiled only when INFOGRAMS is defined.
 * It differs from svr_DoRequest():
 *   - it calls reply_create() with three arguments, whereas
 *     svr_DoRequest() uses reply_create(reqID, payload);
 *   - it opens the stream with stream_openstring(STREAM_WIRE), whereas
 *     svr_DoRequest() uses stream_createBuffer(SDR_WIRE).
 * Confirm these calls and the ->len / ->data stream fields still match
 * the current APIs before enabling INFOGRAMS.
 */
static void
svr_SendInfoGram(Channel *c, const char *msg, int id)
{
  Reply *reply;
  SDR_stream *reply_strm;
  reqlen_t total = 0;
  /* char * rather than void *: the payload is copied at a byte offset
   * below, and pointer arithmetic on void * is not valid ISO C. */
  char *message;

  reply = reply_create(NET_INFO, msg, id);
  reply_strm = stream_openstring(STREAM_WIRE);
  sdr_write(SERVERREPLY, reply_strm, reply);
  total = reply_strm->len;

  /* Make one message containing the net-order of the size followed
   * by the actual bytes of the message: */
  total = htonl(total);
  message = (char *)GC_MALLOC_ATOMIC(sizeof(total) + reply_strm->len);
  if (message == NULL)
    return;			/* best effort: drop the info-only message */

  memcpy(message, &total, sizeof(total));
  memcpy(message + sizeof(total), reply_strm->data, reply_strm->len);
  chan_awrite(c, message, sizeof(total) + reply_strm->len, NULL);
}
#endif

/* Decode the arguments of an incoming request into *a.
 *
 * The opcodes[] entry for req->opcode gives the expected argument
 * count and the type of each argument position.  Each argument is
 * stored in the Arguments slot for its declared type:
 *   - Successive TY_WireString arguments fill string, string2 and
 *     string3, in that order.  Any further string overwrites string3.
 *   - TY_WireUnsigned arguments fill uns, uns2 and uns3 the same way.
 * Every slot that is not filled is left NULL.
 *
 * Nothing is returned; all failures are reported by throwing:
 *   ExBadOpcode  req->opcode is outside the opcodes[] table.  The
 *                opcode arrives in a client request, so it must be
 *                range-checked before it is used as an array index.
 *   ExMalformed  the request carries a different number of arguments
 *                than the opcode expects.
 *   ExBadValue   the opcode table names an argument type this function
 *                does not know how to decode.
 */
static void
svr_GetArgs(Request *req, Arguments *a)   /* throws ExBadOpcode, ExMalformed, ExBadValue */
{
  struct OpcodeType *op;
  unsigned u;

  memset(a, 0, sizeof(*a));

  /* Two explicit comparisons keep the check correct whether the opcode
   * field is a signed or an unsigned integer type. */
  if (req->opcode < 0 || req->opcode >= OP_total) {
    THROW(ExBadOpcode, "Unknown wire protocol request");
    return;			/* never index opcodes[] with a bad opcode */
  }
  op = &opcodes[req->opcode];

  if (vec_size(req->args) != op->numargs) {
    THROW(ExMalformed, "Malformed network request");
    return;
  }

  for (u = 0; u < op->numargs; u++) {
    switch (op->intypes[u]) {
    case TY_Null:
      break;

    case TY_Serializable:
      a->ser = vec_fetch(req->args, u, Serializable *);
      break;

    case TY_WireString:
      if (a->string == NULL) {
        a->string = vec_fetch(req->args, u, WireString *);
      } else if (a->string2 == NULL) {
        a->string2 = vec_fetch(req->args, u, WireString *);
      } else {
        a->string3 = vec_fetch(req->args, u, WireString *);
      }
      break;

    case TY_WireUnsigned:
      if (a->uns == NULL) {
        a->uns = vec_fetch(req->args, u, WireUnsigned *);
      } else if (a->uns2 == NULL) {
        a->uns2 = vec_fetch(req->args, u, WireUnsigned *);
      } else {
        a->uns3 = vec_fetch(req->args, u, WireUnsigned *);
      }
      break;

    case TY_WireDelta:
      a->wd = vec_fetch(req->args, u, WireDelta *);
      break;

    case TY_User:
      a->user = vec_fetch(req->args, u, User *);
      break;

    case TY_PubKey:
      a->pubkey = vec_fetch(req->args, u, PubKey *);
      break;

    case TY_Mutable:
      a->mut = vec_fetch(req->args, u, Mutable *);
      break;

    case TY_StrVec:
      a->names = vec_fetch(req->args, u, StrVec *);
      break;

    default:
      THROW(ExBadValue, "Bad argument type");
      break;
    }
  }
}

/* Defined in another translation unit.  svr_DoRequest() calls it after
 * queueing the reply to an OP_ShutDown request.  Function declarations
 * have external linkage by default, so no `extern` is needed. */
void svr_Cleanup(void);

/* Execute one client request and transmit its reply.
 *
 * Flow:
 *   1. If opt_TraceProtocol is set, trace the incoming request.
 *   2. Decode the arguments with svr_GetArgs() and dispatch on
 *      req->opcode.  Each case calls the matching repos_*() function
 *      on the session's repository (client->repos) and builds a Reply
 *      tagged with req->reqID.
 *   3. Any exception raised in step 2 is converted into a
 *      WireException.  That exception becomes the reply payload, so
 *      such errors are reported to the client instead of propagating
 *      out of this function.
 *   4. Serialize the reply and compress it if the session's
 *      compression level, as cached on entry, is non-zero.
 *   5. Queue the reply on the channel: first a reqlen_t byte count in
 *      network byte order, then the body bytes.
 *   6. For OP_ShutDown, call svr_Cleanup() after the reply has been
 *      queued.
 *
 * c   - channel the request arrived on.  It is also treated as the
 *       client's Session (see the cast below).
 * req - decoded request.  req->opcode selects the operation, and
 *       req->reqID is echoed back in the reply.
 */
void 
svr_DoRequest(Channel *c, Request *req) 
{
  Reply *reply;
  OC_bool shouldQuit = FALSE;	/* set only by OP_ShutDown */

  /* Running count of handled requests.  It is incremented at the end
   * of this function but not read anywhere else in this file. */
  static uint32_t nRequests = 0;

  SDR_stream *reply_strm;

  Arguments a;
  /* NOTE(review): this cast assumes every Channel passed to this
   * function is the leading part of a Session object.  Confirm against
   * the Session definition. */
  Session *client = (Session *)c;
  Repository *authRepos = client->repos;

  /* Need to cache the compression level, because it may change in
     SetCompression(), but must not take effect until after the reply
     is transmitted. */
  int compressionLevel = client->compressionLevel;

  if (opt_TraceProtocol) {
    log_trace(DBG_PROTO, "<-- Processing request: -->\n");
    log_trace(DBG_PROTO, "Connection:      %d\n", c->connection_id);
    request_trace(req);
    log_trace(DBG_PROTO, "<-- End        request  -->\n");
  }

  TRY {
    /* Throws if the argument count or types do not match the opcode
     * table.  The DEFAULT handler below turns that into an error reply. */
    svr_GetArgs(req, &a);

    /* FIX:  Need a nice command matrix for this switch-case-block */
    /* Now, execute the appropriate command and generate a reply. */
    switch (req->opcode) {

    case OP_Disconnect:
      {
	reply = reply_create(req->reqID, 0);
	break;
      }

    case OP_GetVersion:
      {
	Serializable *s = (Serializable *)
	  wires_create(repos_GetVersion(authRepos));
	reply = reply_create(req->reqID, s);
	break;
      }

    case OP_CreateMutable:
      {
	Mutable *m = 
	  repos_CreateMutable(authRepos, a.string->value, a.string2->value,
	                      a.ser, a.uns->value); 
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }

    case OP_DupMutable:
      {
	Mutable *m =
	  repos_DupMutable(authRepos, a.string->value, a.string2->value, 
			   a.uns->value, a.uns2->value, a.uns3->value);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }

    case OP_GetMutable:
      {
	Mutable *m =
	  repos_GetMutable(authRepos, a.string->value);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }

    case OP_ReviseMutable:
      {
	Mutable *m = 
	  repos_ReviseMutable(authRepos, a.string->value,
			      a.uns->value, a.ser);

	/* Unless opt_Notify is "/dev/null", spawn this same program
	 * (appInvokedName) with the "logmail" command.  The subprocess
	 * is given:
	 *   - opencm://localhost as the repository;
	 *   - <repository path>/config/server as the -u argument;
	 *   - <mutable>/<revision number> as the target.
	 * It runs with /dev/null and SPF_DISCONNECT.
	 * NOTE(review): the mutable has already been revised at this
	 * point.  If any call below throws, the client gets an exception
	 * reply for a revision that actually succeeded.  Consider
	 * isolating notification failures from the commit result. */
	if (!nmequal(opt_Notify, "/dev/null")) {
	  SubProcess *mailProc;
	  const char * svrPem = path_join(authRepos->uri->path, "config");
	  svrPem = path_join(svrPem, "server");
	  mailProc = subprocess_create();
	  subprocess_AddArg(mailProc, appInvokedName);
	  subprocess_AddArg(mailProc, "--repository");
	  subprocess_AddArg(mailProc, "opencm://localhost");
	  subprocess_AddArg(mailProc, "--notify");
	  subprocess_AddArg(mailProc, opt_Notify);
	  subprocess_AddArg(mailProc, "-u");
	  subprocess_AddArg(mailProc, svrPem);
	  subprocess_AddArg(mailProc, "logmail");
	  subprocess_AddArg(mailProc, 
			    path_join(a.string->value, 
				      xunsigned64_str(a.uns->value)));
	  subprocess_Run(mailProc, "/dev/null", 0, 0, SPF_DISCONNECT);
	}

	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }

    case OP_GetRevision:
      {
	Revision *rev =
	  repos_GetRevision(authRepos, a.string->value, a.uns->value);
	reply = reply_create(req->reqID, (Serializable *)rev);
	break;
      }

    case OP_PutRevision:
      {
	repos_PutRevision(authRepos, (Revision *) a.ser);
	reply = reply_create(req->reqID, 0);
	break;
      }

    case OP_SetMutableACL:
      {
	Mutable *m = 
	  repos_SetMutableACL(authRepos, a.string->value,
			      a.uns->value, a.string2->value);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }
      
    case OP_SetMutableFlags:
      {
	Mutable *m =
	  repos_SetMutableFlags(authRepos, a.string->value,
				a.uns->value);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }
      
    case OP_SetMutableName:
      {
	Mutable *m =
	  repos_SetMutableName(authRepos, a.string->value, a.string2->value);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }
      
    case OP_SetMutableDesc:
      {
	Mutable *m =
	  repos_SetMutableDesc(authRepos, a.string->value, a.string2->value);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }
      
    case OP_GetEntity:
      {
	Serializable *e =
	  repos_GetEntity(authRepos, a.string->value, 
			  a.string2->value);
	reply = reply_create(req->reqID, (Serializable *)e);
	break;
      } 

    case OP_GetEntityDelta:
      {
	WireDelta *wd =
	  repos_GetEntityDelta(authRepos, a.string->value, 
			       a.string2->value, a.string3->value);
	reply = reply_create(req->reqID, (Serializable *)wd);
	break;
      } 

    case OP_ReviseEntity:
      {
	repos_ReviseEntity(authRepos, a.string->value,
			   a.string2->value, a.ser);
	reply = reply_create(req->reqID, 0);
	break;
      } 

    case OP_ReviseEntityDelta:
      {
	repos_ReviseEntityDelta(authRepos, a.string->value, 
				a.string2->value, a.wd);
	reply = reply_create(req->reqID, 0);
	break;
      } 

    case OP_RebindUser:
      {
	Mutable *m =
	  repos_RebindUser(authRepos, a.string->value, a.pubkey);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }

    case OP_BindUser:
      {
	Mutable *m =
	  repos_BindUser(authRepos, a.pubkey, a.uns->value);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }

    case OP_GetUser:
      {
	Mutable *m =
	  repos_GetUser(authRepos, a.pubkey);
	reply = reply_create(req->reqID, (Serializable *)m);
	break;
      }

    case OP_GetUserAccess:
      {
	unsigned int access = repos_GetUserAccess(authRepos, a.pubkey);
	Serializable *s = (Serializable *) wireu_create(access);
	reply = reply_create(req->reqID, s);
	break;
      }

    case OP_ShutDown:
      {
	/* Perform an orderly shutdown of the repository. */
	repos_ShutDown(authRepos);
	reply = reply_create(req->reqID, NULL);
	/* Defer svr_Cleanup() until the reply below has been queued. */
	shouldQuit = TRUE;
	break;
      }

    case OP_GetParents:
      {
	TnVec *tv =
	  repos_GetParents(authRepos, a.string->value, 
			   a.string2->value);
	reply = reply_create(req->reqID, (Serializable *)tv);
	break;
      } 

    case OP_SetCompression:
      {
	/* Only "gzip" at levels 0..9 is accepted.  The level is stored on
	 * the session.  Because compressionLevel was cached on entry, the
	 * new level applies from the next reply on, not to this one. */
	if (!nmequal(a.string->value, "gzip"))
	  THROW(ExBadValue, 
		format("Unknown compression type \"%s\"", a.string->value));
	if (a.uns->value > 9)
	  THROW(ExBadValue, 
		format("Compression level %s is invalid", 
		       xunsigned64_str(a.uns->value)));

	client->compressionLevel = a.uns->value;

	reply = reply_create(req->reqID, NULL);
	break;
      } 

    case OP_Enumerate:
      {
	TnVec *tv =
	  repos_Enumerate(authRepos, a.string->value, 
			  a.uns->value, a.uns2->value);
	reply = reply_create(req->reqID, (Serializable *)tv);
	break;
      }

    case OP_GarbageCollect:
      {
	repos_GarbageCollect(authRepos, a.uns->value);
	reply = reply_create(req->reqID, NULL);
	break;
      }

    default:
      {
	THROW(ExBadOpcode, "Unknown wire protocol request");
	break;
      }
    } 
  }

  DEFAULT(AnyException) {
    /* Translate any Exception to a code that can be
     * transmitted back over the wire */
    WireException *ex =
      wireExcpt_create(AnyException, _catch.fname, _catch.line, _catch.str);
    reply = reply_create(req->reqID, (Serializable *)ex);
  }
  END_CATCH;

  if (opt_TraceProtocol) {
    log_trace(DBG_PROTO, "<-- Begin Response: -->\n");
    reply_trace(reply);
    log_trace(DBG_PROTO, "<-- End   Response  -->\n");
  }

  /* Now convert reply to a raw message and send it back
   * via the Channel
   */
  reply_strm = stream_createBuffer(SDR_WIRE);
  sdr_write(SERVERREPLY, reply_strm, reply);

  /* We are making potentially many copies here. Nullify reply pointer
     so that state can be GC'd: */
  reply = 0;
  
  /* Make one message containing the net-order of the size followed
   * by the actual bytes of the message: */
  {

    /* Make sure this buffer is malloc'd:  chan_awrite() relies
     * on the buffer pointer's validity well beyond the scope of this
     * local declaration block. */
    char *totbuf = (char *)GC_MALLOC_ATOMIC(sizeof(reqlen_t));

    Buffer *buf = stream_asBuffer(reply_strm);
    ocmoff_t end, pos;
    reqlen_t total;

    /* Uses the level cached on entry (see above), so an OP_SetCompression
     * reply is still sent at the previous level. */
    if (compressionLevel)
    {
      buf = buffer_compress(buf, compressionLevel);
      assert(buffer_length(buf) > 0);
    }

    end = buffer_length(buf);
    pos = 0;
    /* NOTE(review): ocmoff_t -> reqlen_t may narrow if reqlen_t is the
     * smaller type.  Presumably replies never get that large; confirm. */
    total = end;

    total = htonl(total);
    memcpy(totbuf, &total, sizeof(total));

    /* Length prefix first, then the body in as many chunks as the
     * buffer hands back.  Each chunk is queued without copying. */
    chan_awrite(c, totbuf, sizeof(total), NULL);

    while (pos < end) {
      BufferChunk bc = buffer_getChunk(buf, pos, end - pos);
      assert(bc.len <= (end - pos));

      chan_awrite(c, bc.ptr, bc.len, NULL);

      pos += bc.len;
    }
    buf = 0;			/* GC */
  }

  stream_close(reply_strm);
  reply_strm = 0;		/* GC */

  /* NOTE(review): unlike the traces above, this one is not guarded by
   * opt_TraceProtocol.  Confirm whether log_trace() filters on
   * DBG_PROTO by itself. */
  log_trace(DBG_PROTO, "<-- Response Transmitted -->\n");

  /* Disabled GC heap diagnostics. */
#if 0
  report(1, "GC heap size: %d\n", GC_get_heap_size());
  log_error("Bytes alloc since last GC: %d\n", GC_get_bytes_since_gc());
#endif

  nRequests ++;

  if (shouldQuit)
    svr_Cleanup();
} 
