/*
 * Copyright (c) 2002, The EROS Group, LLC and Johns Hopkins
 * University. All rights reserved.
 * 
 * This software was developed to support the EROS secure operating
 * system project (http://www.eros-os.org). The latest version of
 * the OpenCM software can be found at http://www.opencm.org.
 * 
 * Redistribution and use in source and binary forms, with or
 * without modification, are permitted provided that the following
 * conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 
 * 2. Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials
 *    provided with the distribution.
 * 
 * 3. Neither the name of the The EROS Group, LLC nor the name of
 *    Johns Hopkins University, nor the names of its contributors
 *    may be used to endorse or promote products derived from this
 *    software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <opencm.h>
#include <repos/opencmrepos.h>

/* Counter used to number outgoing requests.  It is bumped before each
 * use, so the first request created by this file carries id 1. */
static unsigned long request_id = 0;

/* make_request(Foo) builds a request whose opcode is OP_Foo. */
#define make_request(x) do_make_request(OP_##x)

/*
 * Allocate a new Request for OPCODE, tagging it with the next value
 * of the file-wide request counter.
 */
static Request *
do_make_request(unsigned int opcode)
{
  request_id += 1;

  return request_create(opcode, request_id);
}

/* Reply classification helpers.  R is a pointer to an object with an
 * `rtype' member.  The argument is parenthesized so that expressions
 * such as `p + 1' or `cond ? a : b' expand correctly. */
#define GOOD_REPLY(r)  ((r)->rtype != NET_ERROR)
#define IS_INFOGRAM(r) ((r)->rtype == NET_INFO)

/* Encapsulates all network-protocol-specific details of sending a
 * request to the server: the request is serialized and written to the
 * repository's channel, then one reply is read back.  Returns the
 * value carried by the reply; if that value is a WireException it is
 * re-thrown locally instead of being returned.
 *
 * Declared here because the per-operation stubs below call it before
 * its definition. */
Serializable * invoke_server (Repository *r, Request *);

/* Per-connection state for a network repository; stored in the
 * repository's `info' slot by netrepository_connect(). */
typedef struct NetConnInfo {
  Channel *chan;             /* channel all requests/replies travel over */
  int compressionLevel;      /* 0 means messages are sent uncompressed */
  unsigned int protoVersion; /* protocol version reported by the server */
} NetConnInfo;

/*
 * Connect repository R to its server and start a session.
 *
 * Opens an SSL channel to r->uri, completes the handshake, sends a
 * StartSession request (repository path + highest client protocol
 * version), and records the repository id and protocol version found
 * in the reply.  Afterwards the user bound to PK is looked up and
 * stored in r->authMutable / r->authUser, and compression is requested
 * if opt_CompressionLevel is non-zero.
 *
 * Throws ExNoConnect if the channel cannot be opened, and ExBadValue
 * if the StartSession reply is missing or the server's protocol
 * version is older than CM_LEAST_PROTO_VER.  Exceptions raised by the
 * helpers called here propagate unchanged.
 */
static void
netrepository_connect(Repository *r, PubKey *pk)
{
  Serializable *s;
  Channel *c;
  WireString *pathArg;
  WireString *id;
  WireUnsigned *svr_protoVersion;
  ObVec *ret;
  Request *sr = make_request(StartSession);  
  NetConnInfo *nci = GC_MALLOC(sizeof(NetConnInfo));

  /* First arg is what client thinks the repos path is */
  pathArg = wires_create(r->uri->path);
  request_AddArgument(sr, pathArg);

  /* Second arg is the highest protocol version client supports */
  request_AddArgument(sr, wireu_create(CM_PROTO_VER));

  c = chan_connect_ssl(r->uri);

  /* Check for TCP connect errors.  E.g. client specified wrong
   * port number, wrong server name, etc.  This is the only place the
   * check is needed: c is dereferenced immediately afterwards. */
  if (c == NULL)
    THROW(ExNoConnect, 
	  format("Couldn't connect to net repository at %s", r->uri->URI));

  /* Wait for the handshake to complete */
  chan_shake(c);

  assert(c->peerCert);

  nci->chan = c;
  nci->compressionLevel = 0;	/* until after handshake */
  nci->protoVersion = 0;	/* until the server reports its version */

  /* invoke_server() reads the channel out of r->info, so this must be
   * in place before the first request is sent. */
  r->info = nci;

  r->svrPubKey = c->peerCert;

  ret = (ObVec *) invoke_server(r, sr);

  /* invoke_server() may hand back a NULL value; do not index into it. */
  if (ret == NULL)
    THROW(ExBadValue, "Malformed StartSession reply from server");

  /* Extract the repos id from the server response */
  id = (WireString *)vec_fetch(ret, 0);

  /* Extract the protocol version the server reported */
  svr_protoVersion = (WireUnsigned *)vec_fetch(ret, 1);

  if (id == NULL || svr_protoVersion == NULL)
    THROW(ExBadValue, "Malformed StartSession reply from server");

  if (svr_protoVersion->value < CM_LEAST_PROTO_VER)
    THROW(ExBadValue, "Server no longer supports client protocol version. Upgrade the client.");

  nci->protoVersion = svr_protoVersion->value;
  r->repositoryID = id->value;

  repos_validate_pubkey(r->repositoryID, c->peerCert);

  r->authMutable = repos_GetUser(r, pk);
  s = repos_GetMutableContent(r, r->authMutable);

  r->authUser = (User *)s;
  r->authAccess = 0x0;  /* not needed in this case */

  if (opt_CompressionLevel)
    repos_SetCompression(r, "gzip", opt_CompressionLevel);
}

/*
 * Send a Disconnect request to the server.  This is best-effort: any
 * exception raised while doing so is caught and discarded here, so
 * that a failing disconnect cannot trigger further error handling
 * (and hence endless loops) at the client level.
 */
static void
netrepository_disconnect(Repository *r)
{
  TRY {
    Request *req = make_request(Disconnect);

    invoke_server(r, req);
  }
  DEFAULT(anyex) {
    /* deliberately ignored */
  }
  END_CATCH;
}

/*
 *  Create a Request with the opcode for "get version",
 *  send it upstream and wait for the reply.
 *
 *  Returns a copy (made with xstrdup) of the version string carried
 *  in the reply.  Throws ExNullArg if the server's reply carries no
 *  value, rather than dereferencing a NULL pointer.
 */
static const char *
netrepository_GetVersion(Repository *r)
{
  WireString *repos_version;
  Request *sr = make_request(GetVersion);  

  repos_version = (WireString *) invoke_server(r, sr);

  /* invoke_server() returns the reply's value as-is, which may be NULL. */
  if (repos_version == NULL)
    THROW(ExNullArg, "Null/garbled reply from server");

  return xstrdup(repos_version->value);
} 
  
/*
 * Build a "create mutable" request and wait for the reply.
 *
 * The arguments are appended in this order: name, description, the
 * object S (passed through as a Serializable), and flags.  The value
 * returned by the server is handed back as a Mutable.
 */
static Mutable *
netrepository_CreateMutable(Repository *r, const char *nm, const char *desc,
			    void *s, unsigned flags)
{
  Request *req = make_request(CreateMutable);

  request_AddArgument(req, wires_create(nm));
  request_AddArgument(req, wires_create(desc));
  request_AddArgument(req, (Serializable *)s);
  request_AddArgument(req, wireu_create(flags));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Build a DupMutable request and wait for the reply.
 *
 * The arguments are appended in this order: new name, URI of the
 * mutable to duplicate, the keepACLs flag, the revision number, and
 * flags.  The value returned by the server is handed back as a
 * Mutable.
 */
static Mutable *
netrepository_DupMutable(Repository *r, const char *nm,
			 const char *mURI, OC_bool keepACLs,
			 oc_uint64_t rev,
			 unsigned flags)
{
  Request *req = make_request(DupMutable);

  request_AddArgument(req, wires_create(nm));
  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wireu_create(keepACLs));
  request_AddArgument(req, wireu_create(rev));
  request_AddArgument(req, wireu_create(flags));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Build a SetMutableFlags request carrying the mutable's URI and the
 * new flags, send it, and hand back the server's reply as a Mutable.
 */
static Mutable *
netrepository_SetMutableFlags(Repository *r, const char *mURI, unsigned flags)
{
  Request *req = make_request(SetMutableFlags);

  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wireu_create(flags));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Build a SetMutableName request carrying the mutable's URI and the
 * new name, send it, and hand back the server's reply as a Mutable.
 */
static Mutable *
netrepository_SetMutableName(Repository *r, const char *mURI, const char *name)
{
  Request *req = make_request(SetMutableName);

  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wires_create(name));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Build a SetMutableDesc request carrying the mutable's URI and the
 * new description, send it, and hand back the server's reply as a
 * Mutable.
 */
static Mutable *
netrepository_SetMutableDesc(Repository *r, const char *mURI, const char *desc)
{
  Request *req = make_request(SetMutableDesc);

  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wires_create(desc));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Fetch the mutable named MNAME (a URI) from the server.
 *
 * Throws ExNullArg when MNAME is NULL.  Otherwise the name is sent as
 * the single argument of a GetMutable request and the server's reply
 * is handed back as a Mutable.
 */
static Mutable *
netrepository_GetMutable(Repository *r, const char *mName)
{
  /* The request is created before the argument check, so a request
   * id is consumed even when the check below throws. */
  Request *req = make_request(GetMutable);

  if (mName == NULL)
    THROW(ExNullArg, "No mutable name provided to netrepository_GetMutable.\n");

  request_AddArgument(req, wires_create(mName));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Build a ReviseMutable request and wait for the reply.
 *
 * Throws ExNullArg when MURI is NULL.  Otherwise the arguments are
 * appended in this order: the mutable's URI, the caller's idea of the
 * current top revision number, and the object V.  The server's reply
 * is handed back as a Mutable.
 */
static Mutable *
netrepository_ReviseMutable(Repository *r, const char *mURI, 
			    oc_uint64_t curTopRev, void *v)
{
  Request *req = make_request(ReviseMutable);

  if (mURI == NULL)
    THROW(ExNullArg, "No mutable to netrepository_ReviseMutable!\n");

  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wireu_create(curTopRev));
  request_AddArgument(req, v);

  return (Mutable *) invoke_server(r, req);
}

/*
 * Fetch revision number REVNO of the mutable identified by MURI.
 *
 * Throws ExNullArg when MURI is NULL.  Otherwise the URI and the
 * revision number are sent as the arguments of a GetRevision request
 * and the server's reply is handed back as a Revision.
 */
static Revision *
netrepository_GetRevision(Repository *r, const char *mURI, 
			  oc_uint64_t revNo)
{
  Revision *rev;

  Request *sr = make_request(GetRevision);

  /* The message names this function; it previously referred to a
   * non-existent "netrepository_GetRevisions". */
  if (mURI == NULL)
    THROW(ExNullArg, "No mutable to netrepository_GetRevision!\n");

  /* Now add the mutable's URI and the revision number as arguments: */
  request_AddArgument(sr, wires_create(mURI));
  request_AddArgument(sr, wireu_create(revNo));

  rev = (Revision *) invoke_server(r,sr);

  return rev;
}

/*
 * Send REV to the server as the single argument of a PutRevision
 * request.  The reply value is not used.
 */
static void
netrepository_PutRevision(Repository *r, Revision *rev)
{
  Request *req = make_request(PutRevision);

  request_AddArgument(req, rev);

  invoke_server(r, req);
}

/*
 * Build a SetMutableACL request and wait for the reply.
 *
 * Throws ExNullArg when either MURI or ACLURI is NULL (MURI is checked
 * first).  Otherwise the arguments are appended in this order: the
 * mutable's URI, WHICH, and the ACL's URI.  The server's reply is
 * handed back as a Mutable.
 */
static Mutable *
netrepository_SetMutableACL(Repository *r, const char *mURI, unsigned which, 
			    const char *aclURI)
{
  Request *req = make_request(SetMutableACL);

  if (mURI == NULL)
    THROW(ExNullArg, "No mutable to netrepository_SetMutableACL.\n");

  if (aclURI == NULL)
    THROW(ExNullArg, "No uri to netrepository_SetMutableACL.\n");

  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wireu_create(which));
  request_AddArgument(req, wires_create(aclURI));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Send SEROBJECT to the server as a ReviseEntity request for the
 * mutable MURI and true name TNAME.
 *
 * Throws ExNullArg when MURI or SEROBJECT is NULL, and ExNoAccess
 * when the network location in MURI does not match this repository's
 * id.  The reply value is not used.
 */
static void
netrepository_ReviseEntity(Repository *r,
			   const char *mURI,
			   const char *tName, void *serobject) 
{ 
  Request *sr = make_request(ReviseEntity);
  URI *uri;

  /* Check MURI before parsing it, as the sibling operations do,
   * instead of handing a NULL string to uri_create(). */
  if (mURI == NULL)
    THROW(ExNullArg, "No mutable to netrepository_ReviseEntity.\n");

  uri = uri_create(mURI);

  if (serobject == NULL)
    THROW(ExNullArg, "No object sent to netrepository_ReviseEntity.\n");
 
  if (!nmequal(uri->netloc, r->repositoryID))
    THROW(ExNoAccess, "Can't revise on non-owning repository.\n");

  /* Now add the arguments */
  request_AddArgument(sr, wires_create(mURI));
  request_AddArgument(sr, wires_create(tName));
  request_AddArgument(sr, serobject);

  invoke_server(r,sr);
} 

/*
 * Build a BindUser request carrying the public key PK and the ACCESS
 * value, send it, and hand back the server's reply as a Mutable.
 */
static Mutable *
netrepository_BindUser(Repository *r, PubKey *pk, unsigned access)
{
  Request *req = make_request(BindUser);

  request_AddArgument(req, pk);
  request_AddArgument(req, wireu_create(access));

  return (Mutable *) invoke_server(r, req);
}

/*
 * Build a GetUser request carrying the public key PK, send it, and
 * hand back the server's reply as a Mutable.
 */
static Mutable *
netrepository_GetUser(Repository *r, PubKey *pk)
{
  Request *req = make_request(GetUser);

  request_AddArgument(req, pk);

  return (Mutable *) invoke_server(r, req);
}

/*
 * Build a GetUserAccess request carrying the public key PK, send it,
 * and return the unsigned value carried in the server's reply.
 *
 * Throws ExNullArg if the reply carries no value, rather than
 * dereferencing a NULL pointer.
 */
static unsigned int
netrepository_GetUserAccess(Repository *r, PubKey *pk)
{
  WireUnsigned *access = NULL;
  Request *sr = make_request(GetUserAccess);

  request_AddArgument(sr, pk);

  access = (WireUnsigned *) invoke_server(r, sr);

  /* invoke_server() returns the reply's value as-is, which may be NULL. */
  if (access == NULL)
    THROW(ExNullArg, "Null/garbled reply from server");

  return access->value;
}

/*
 * Fetch an entity from the server.
 *
 * MURI and TNAME (the entity's true name) are sent, in that order, as
 * the arguments of a GetEntity request.  Whatever value the server's
 * reply carries is returned to the caller untyped.
 */
static void *
netrepository_GetEntity(Repository *r, const char *mURI, const char *tName)
{
  Request *req = make_request(GetEntity);

  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wires_create(tName));

  return (Serializable *) invoke_server(r, req);
}

/*
 * Send an argument-less ShutDown request to the server.  The reply
 * value is discarded.
 */
static void
netrepository_ShutDown(Repository *r)
{
  Request *req = make_request(ShutDown);

  (void) invoke_server(r, req);
}

/*
 * Build a GetParents request carrying MURI and the true name TNAME
 * (in that order), send it, and hand back the server's reply as a
 * TnVec.
 */
static TnVec *
netrepository_GetParents(Repository *r, const char *mURI, const char *tName)
{
  Request *req = make_request(GetParents);

  request_AddArgument(req, wires_create(mURI));
  request_AddArgument(req, wires_create(tName));

  return (TnVec *) invoke_server(r, req);
}

/*
 * Ask the server to switch to compression METHOD at LEVEL, and from
 * then on use LEVEL for this connection's messages.
 *
 * The request is silently skipped when the connection's protocol
 * version is below 4, which does not support the SetCompression
 * operation.
 */
static void
netrepository_SetCompression(Repository *r, const char *method, unsigned level)
{
  NetConnInfo *conn = (NetConnInfo *)r->info;
  Request *req;

  /* Once every supported protocol version understands SetCompression,
     the version test below is dead weight; fail the build so that it
     gets removed. */
#if CM_LEAST_PROTO_VER >= 4
#error "The following test is now obsolete and can be removed."
#endif

  if (conn->protoVersion < 4)
    return;

  req = make_request(SetCompression);

  /* Arguments: compression method name, then compression level. */
  request_AddArgument(req, wires_create(method));
  request_AddArgument(req, wireu_create(level));

  invoke_server(r, req);

  /* Only switch locally after the request itself has gone through
     with the previous setting. */
  conn->compressionLevel = level;
}

/*
 * Send REQUEST to the server behind repository R and return the value
 * carried by the matching reply.
 *
 * Wire format, as implemented below: the serialized (and, when a
 * compression level is set, compressed) request is preceded by its
 * length as a reqlen_t converted with htonl().  The reply is framed
 * the same way.
 *
 * Throws:
 *   ExConnLost  - R has no connection info or no channel.
 *   ExBadValue  - the reply buffer could not be allocated, or the
 *                 reply's request id does not match REQUEST's.
 *   ExNullArg   - the reply could not be deserialized.
 *   Any exception raised by the channel read/write calls is re-thrown.
 *   If the reply value is a WireException, it is re-thrown locally
 *   with "(server)" prefixed to its file name.
 *
 * The returned value may be NULL; callers that dereference it must
 * check.
 */
Serializable * 
invoke_server (Repository *r, Request *request) 
{

  SDR_stream* net_stream;
  char *rawReply;
  Reply *sreply = NULL;
  reqlen_t total;
	
  NetConnInfo *nci = (NetConnInfo *)r->info;

  if (opt_TraceProtocol) {
    report(0, "<-- Begin Request: -->\n");
    request_trace(request);
    report(0, "<-- End   Request  -->\n");
  }

  /* r->info is only filled in once a connection has been opened, so
   * guard against a repository that never got that far (e.g. a
   * disconnect issued after a failed connect). */
  if (nci == NULL || !nci->chan)
    THROW(ExConnLost, "Network connection was lost");

  net_stream = stream_createBuffer(SDR_WIRE);

  /* Now serialize the request: */
  sdr_write (SERVERREQUEST, net_stream, request);
	  
  /* First, send the message size across the wire, then send the
   * actual message: */
  {
    Buffer *reqBuf = stream_asBuffer(net_stream);
    ocmoff_t end, pos;

    if (nci->compressionLevel)
      reqBuf = buffer_compress(reqBuf, nci->compressionLevel);

    end = buffer_length(reqBuf);
    pos = 0;

    total = end;
    total = htonl(total);

    TRY {
      chan_write(nci->chan, &total, sizeof(total), NULL);
    }
    DEFAULT(ex) {
      stream_close(net_stream);
      RETHROW(ex);
    }
    END_CATCH;

    /* The buffer is handed out in chunks; write them one by one. */
    while (pos < end) {
      BufferChunk bc = buffer_getChunk(reqBuf, pos, end - pos);
      assert(bc.len <= (end - pos));

      TRY {
	chan_write(nci->chan, bc.ptr, bc.len, NULL);
      }
      DEFAULT(ex) {
	stream_close(net_stream);
	RETHROW(ex);
      }
      END_CATCH;
      
      pos += bc.len;
    }
  }

  stream_close(net_stream);
  /* Use infinite loop to allow server to respond with 'info grams'
   * while we're waiting on the actual Reply.  Infograms are not
   * handled yet, so the loop currently always exits after one reply. */
  for(;;) {
    Buffer *buf = buffer_create();
      
    total = 0u;
    TRY {
      chan_read(nci->chan, &total, sizeof(total), NULL);
    }
    DEFAULT(ex) {
      xprintf("Could not read message length from server");
      RETHROW(ex);
    }
    END_CATCH;

    /* Now that we know the length of the reply, try
     * to read the actual reply:
     */
    total = ntohl(total);

    rawReply = (void *)GC_MALLOC_ATOMIC(total);

    /* The length comes from the peer; do not touch the buffer if the
     * allocator could not satisfy it. */
    if (rawReply == NULL)
      THROW(ExBadValue, "Could not allocate buffer for server reply");

    memset(rawReply, 0, total);

    TRY {
      chan_read(nci->chan, (void *)rawReply, total, NULL);
    }
    DEFAULT(ex) {
      xprintf("Could not read response from server\n");
      RETHROW(ex);
    }
    END_CATCH;

    buffer_append(buf, rawReply, total);

    if (nci->compressionLevel)
      buf = buffer_decompress(buf);

    net_stream = stream_fromBuffer(buf, SDR_WIRE);
    sreply = sdr_read(SERVERREPLY, net_stream);
    stream_close(net_stream);

    if (sreply == NULL)
      THROW(ExNullArg, "Null/garbled reply from server");

    if (opt_TraceProtocol) {
      report(0, "<-- Begin Response: -->\n");
      reply_trace(sreply);
      report(0, "<-- End   Response  -->\n");
    }

    if (sreply->reqID != request->reqID)
      THROW(ExBadValue, "Response ID does not match request ID");

    /* FIX: Infograms will eventually be handled by a distinguished
       serializable object type. */

    break;
  }

  /* Here's where we need to translate a wire-error code into
   * an appropriate Exception to pass back up to the client */
  if (sreply->value && GETTYPE(sreply->value) == TY_WireException) {
    WireException *ex = (WireException *)sreply->value;
    _throw(_curCatch, xstrcat("(server)", ex->fname), ex->line, ex->name, ex->str);
  }

  return sreply->value;
} 

/*
 * Fill in repository R's operation table with the network
 * implementations defined in this file, and set its doesAccess flag.
 */
void
netrepository_init(Repository * r) 
{
  r->doesAccess = TRUE;

  /* Session-level operations */
  r->Connect        = netrepository_connect;
  r->Disconnect     = netrepository_disconnect;
  r->GetVersion     = netrepository_GetVersion;
  r->SetCompression = netrepository_SetCompression;
  r->ShutDown       = netrepository_ShutDown;

  /* Entity operations */
  r->GetEntity      = netrepository_GetEntity;
  r->ReviseEntity   = netrepository_ReviseEntity;
  r->GetParents     = netrepository_GetParents;

  /* Mutable operations */
  r->CreateMutable   = netrepository_CreateMutable;
  r->DupMutable      = netrepository_DupMutable;
  r->GetMutable      = netrepository_GetMutable;
  r->ReviseMutable   = netrepository_ReviseMutable;
  r->SetMutableACL   = netrepository_SetMutableACL;
  r->SetMutableFlags = netrepository_SetMutableFlags;
  r->SetMutableName  = netrepository_SetMutableName;
  r->SetMutableDesc  = netrepository_SetMutableDesc;

  /* Revision operations */
  r->GetRevision    = netrepository_GetRevision;
  r->PutRevision    = netrepository_PutRevision;

  /* User operations */
  r->BindUser       = netrepository_BindUser;
  r->GetUser        = netrepository_GetUser;
  r->GetUserAccess  = netrepository_GetUserAccess;
}
