#include "tra.h"

/*
 * The Client belonging to the calling thread, kept in its
 * threaddata() slot (presumably libthread's per-thread pointer);
 * startclient sets it and endclient clears it.  The expansion is
 * parenthesized so uses like curclient->err bind as intended.
 */
#define curclient (*threaddata())

/* Recycled Clients, linked through Client.next. */
static Client *freelist;
/*
 * FIFO of Clients with an RPC outstanding, linked through
 * Client.next.  While pending is non-nil, epending points at the
 * next field of its last element so appends need no walk.
 */
static Client *pending, **epending;

static Client*
allocclient(void)
{
	int i;
	Client *cli;

	if(freelist == nil){
		freelist = emalloc(sizeof(Client)*64);
		for(i=0; i<64-1; i++)
			freelist[i].next = &freelist[i+1];
		freelist[i].next = nil;
		for(i=0; i<64; i++)
			freelist[i].c = chan(Buf*);
	}
	cli = freelist;
	freelist = cli->next;
	return cli;
}

/* Push cli onto the front of the free list for later reuse. */
static void
freeclient(Client *cli)
{
	Client *head;

	head = freelist;
	freelist = cli;
	cli->next = head;
}

/*
 * Find the pending Client whose RPC tag is tag, unlink it from the
 * pending queue and return it.  Returns nil if no pending Client
 * carries that tag.
 *
 * The tag is the Client's address as stored by dorpc
 * (r->tag = (intptr)cli).  It comes back off the wire as a 4-byte
 * LONG, so on 64-bit machines only its low 32 bits survive the round
 * trip.  The candidate pointer is therefore truncated to int the same
 * way before comparing; comparing the full intptr could never match a
 * Client allocated above 4GB.
 *
 * If this gets too expensive, it is possible though not as safe to do
 *	return (Client*)tag;
 * but that would need fiddling on 64-bit machines.
 */
static Client*
findclient(int tag)
{
	Client **l, *t;

	for(l=&pending; *l; l=&(*l)->next){
		if((int)(intptr)*l == tag){
			t = *l;
			*l = t->next;
			t->next = nil;
			if(*l == nil)
				epending = l;	/* removed the tail; l is now the end */
			return t;
		}
	}
	return nil;
}

/* Append cli at the tail of the pending queue. */
static void
queueclient(Client *cli)
{
	Client **tail;

	tail = (pending == nil) ? &pending : epending;
	*tail = cli;
	cli->next = nil;
	epending = &cli->next;
}

void
startclient(void)
{
	curclient = allocclient();
}

/* Detach the calling thread's Client and return it to the free list. */
void
endclient(void)
{
	freeclient(curclient);
	curclient = nil;
}

char*
rpcerror(void)
{
	Client *cli;

	cli = curclient;
	return estrdup(cli->err);
}

/* BUG: this is a copy of replread */
void
replthread(void *a)
{
	uchar hdr[4], *x;
	int n, nn, tag;
	Buf *b, *bb;
	Client *cli;
	Fdbuf *fd;
	Replica *repl;

	repl = a;
	fd = openfdbuf(repl->rfd);
	for(;;){
		if(readnfdbuf(fd, hdr, 4) != 4){
			repl->err = "eof reading input";
			break;
		}
		n = LONG(hdr);
		dbg(DbgRpc, "%p looking for %d\n", repl, n);
		if(n > 1*1024*1024){
			repl->err = "implausible size";
			break;
		}
		b = mkbuf(nil, n);
		if(readnfdbuf(fd, b->p, n) != n){
			free(b);
			repl->err = "eof reading input";
			break;
		}
		dbg(DbgRpc, "%p got %d\n", repl, n);
		if(repl->inflate){
			inzrpctot += b->ep - b->p;
			dbg(DbgRpc, "flate %.*H\n", (int)(b->ep-b->p), b->p);
			nn = readbufl(b);
			if(nn > 1*1024*1024){
				repl->err = "implausible rpc packet";
				break;
			}
			bb = mkbuf(nil, nn);
			if(inflateblock(repl->inflate, bb->p, nn, b->p, b->ep-b->p) != nn){
				repl->err = "error decompressing block";
				break;
			}
			free(b);			
			b = bb;
			dbg(DbgRpc, "inflated %.*H\n", (int)(b->ep-b->p), b->p);
		}
		inrpctot += b->ep - b->p;
		if(b->ep < b->p+2+4){
			free(b);
			repl->err = "short rpc packet";
			break;
		}
		x = b->p+2;
		tag = LONG(x);
		cli = findclient(tag);
		dbg(DbgRpc, "%p got tag %d\n", repl, tag);
		if(cli == nil){
			free(b);
			repl->err = "rpc packet with unexpected tag";
			break;
		}
		sendp(cli->c, b);
	}
	dbg(DbgRpc, "repl closing buf %s\n", repl->err);
	closefdbuf(fd);
	if(repl->err == nil)
		repl->err = "unknown rpc error";
}

/* Count of RPCs currently outstanding, as tracked by dorpc. */
int nrpc;
/* Highest value nrpc has reached (high-water mark). */
int mrpc;

static int
dorpc(Replica *repl, Rpc *r)
{
	Buf *b;
	Chan *c;
	Client *cli;
	Rpc nr;

	cli = curclient;
	r->tag = (intptr)cli;	/* XXX: 64-bit machines */
	c = cli->c;
	b = convR2M(r);
	if(b == nil)
		abort();

	dbg(DbgRpc, "->%p %R\n", repl, r);
	queueclient(cli);
	nrpc++;
	if(nrpc > mrpc)
		mrpc = nrpc;
	if(replwrite(repl, b) < 0){
		rerrstr(cli->err, sizeof cli->err);
		free(b);
		return -1;
	}
	free(b);
	b = recvp(c);
	nrpc--;
	if(b == nil){
		rerrstr(cli->err, sizeof cli->err);
		free(b);
		return -1;
	}
	if(convM2R(b, &nr) < 0){
		rerrstr(cli->err, sizeof cli->err);
		free(b);
		return -1;
	}
	dbg(DbgRpc, "<-%p %R\n", repl, &nr);
	if(nr.type==Rerror){
		utfecpy(cli->err, cli->err+sizeof cli->err, nr.err);
		free(b);
		return -1;
	}
	if(nr.type != r->type+1){
		snprint(cli->err, sizeof cli->err, "bad tag %d expected %d", nr.type, r->type+1);
		free(b);
		return -1;
	}
	if(r->type==Tread || r->type==Treadhash){
		memmove(r->a, nr.a, nr.n);
		nr.a = r->a;
	}
	free(b);
	*r = nr;
	return 0;
}

/*
 * Send a Taddtime RPC for path p carrying times st and mt.
 * Returns 0 on success, -1 on failure (details via rpcerror).
 */
int
rpcaddtime(Replica *repl, Path *p, Vtime *st, Vtime *mt)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.type = Taddtime;
	rpc.mt = mt;
	rpc.st = st;
	rpc.p = p;
	return dorpc(repl, &rpc) < 0 ? -1 : 0;
}

/*
 * Send a Tclose RPC for fd.
 * Returns 0 on success, -1 on failure (details via rpcerror).
 */
int
rpcclose(Replica *repl, int fd)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.fd = fd;
	rpc.type = Tclose;
	return dorpc(repl, &rpc) < 0 ? -1 : 0;
}

/*
 * Send a Tcommit RPC for fd with stat s.
 * Returns 0 on success, -1 on failure (details via rpcerror).
 */
int
rpccommit(Replica *repl, int fd, Stat *s)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.s = s;
	rpc.fd = fd;
	rpc.type = Tcommit;
	return dorpc(repl, &rpc) < 0 ? -1 : 0;
}

/* Send a Tdebug RPC with value debug; returns dorpc's result. */
int
rpcdebug(Replica *repl, int debug)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.n = debug;
	rpc.type = Tdebug;
	return dorpc(repl, &rpc);
}

/*
 * Turn on compression for repl.  flate is passed to deflateinit and
 * sent to the peer in a Tflate RPC.  The inflate/deflate state is
 * created first; it is installed in repl only if the RPC succeeds and
 * released otherwise.
 * Returns 0 on success, -1 on failure.
 */
int
rpcflate(Replica *repl, int flate)
{
	Rpc rpc;
	Flate *in, *out;

	in = inflateinit();
	if(in == nil)
		return -1;
	out = deflateinit(flate);
	if(out == nil){
		inflateclose(in);
		return -1;
	}

	memset(&rpc, 0, sizeof rpc);
	rpc.type = Tflate;
	rpc.n = flate;
	if(dorpc(repl, &rpc) < 0){
		inflateclose(in);
		deflateclose(out);
		return -1;
	}
	dbg(DbgRpc, "%p flate=%d\n", repl, flate);
	repl->inflate = in;
	repl->deflate = out;
	return 0;
}

/*
 * Send a Thangup RPC.
 * Returns 0 on success, -1 on failure (details via rpcerror).
 */
int
rpchangup(Replica *repl)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.type = Thangup;
	return dorpc(repl, &rpc) < 0 ? -1 : 0;
}

/*
 * Build a Hashlist for fd by issuing Treadhash RPCs until one returns
 * 0 or fails.  Each reply is a sequence of records: a 2-byte length
 * (read with SHORT) followed by SHA1dlen bytes of hash.  Records are
 * added at consecutive offsets, each advancing by its length.
 * Returns nil on an RPC error or a reply that is not a whole number
 * of records.
 */
Hashlist*
rpchashfile(Replica *repl, int fd)
{
	uchar buf[8192], *rec, *end;
	Hashlist *hl;
	vlong off;
	int n, len;

	hl = mkhashlist();
	off = 0;
	for(;;){
		n = rpcreadhash(repl, fd, buf, sizeof buf);
		if(n <= 0)
			break;
		if(n % (2+SHA1dlen) != 0){
			werrstr("got bad readhash count %d", n);
			free(hl);
			return nil;
		}
		end = buf + n;
		for(rec=buf; rec<end; rec+=2+SHA1dlen){
			len = SHORT(rec);
			hl = addhash(hl, rec+2, off, len);
			off += len;
		}
	}
	if(n < 0){
		free(hl);
		return nil;
	}
	return hl;
}

/*
 * Send a Tkids RPC for path p.  On success, stores the reply's Kid
 * array in *pkid and returns its count; returns -1 on failure.
 */
int
rpckids(Replica *repl, Path *p, Kid **pkid)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.p = p;
	rpc.type = Tkids;
	if(dorpc(repl, &rpc) >= 0){
		*pkid = rpc.k;
		return rpc.nk;
	}
	return -1;
}

/*
 * Send a Tmeta RPC with string s.
 * Returns the reply's string, or nil on failure.
 */
char*
rpcmeta(Replica *repl, char *s)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.str = s;
	rpc.type = Tmeta;
	return dorpc(repl, &rpc) < 0 ? nil : rpc.str;
}
/* Send a Tmkdir RPC for path p with stat s; returns dorpc's result. */
int
rpcmkdir(Replica *repl, Path *p, Stat *s)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.s = s;
	rpc.p = p;
	rpc.type = Tmkdir;
	return dorpc(repl, &rpc);
}

/*
 * Send a Topen RPC for path p with mode omode.
 * Returns the fd from the reply, or -1 on failure.
 */
int
rpcopen(Replica *repl, Path *p, char omode)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.omode = omode;
	rpc.p = p;
	rpc.type = Topen;
	return dorpc(repl, &rpc) < 0 ? -1 : rpc.fd;
}
	
/*
 * Send a Tread RPC asking for up to n bytes from fd into a.
 * Returns the reply's byte count, or -1 on failure.
 */
long
rpcread(Replica *repl, int fd, void *a, long n)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.n = n;
	rpc.a = a;
	rpc.fd = fd;
	rpc.type = Tread;
	return dorpc(repl, &rpc) < 0 ? -1 : rpc.n;
}

/*
 * Send a Treadhash RPC asking for up to n bytes of hash records for
 * fd into a.  Returns the reply's byte count, or -1 on failure.
 */
long
rpcreadhash(Replica *repl, int fd, void *a, long n)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.n = n;
	rpc.a = a;
	rpc.fd = fd;
	rpc.type = Treadhash;
	return dorpc(repl, &rpc) < 0 ? -1 : rpc.n;
}

/*
 * Read exactly n bytes from fd into a by repeated rpcread calls.
 * Returns the number of bytes actually read; a value below n means an
 * RPC error or an early EOF (the latter also sets "early eof" with
 * werrstr).
 */
long
rpcreadn(Replica *repl, int fd, void *a, long n)
{
	uchar *dst;
	long got, m;

	dst = a;
	got = 0;
	while(got < n){
		m = rpcread(repl, fd, dst+got, n-got);
		if(m < 0)
			break;
		if(m == 0){
			werrstr("early eof");
			break;
		}
		got += m;
	}
	return got;
}

/*
 * Send a Treadonly RPC with value ignwr.
 * Returns 0 on success, -1 on failure (details via rpcerror).
 */
int
rpcreadonly(Replica *repl, int ignwr)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.n = ignwr;
	rpc.type = Treadonly;
	return dorpc(repl, &rpc) < 0 ? -1 : 0;
}

/* Send a Tremove RPC for path p with stat s; returns dorpc's result. */
int
rpcremove(Replica *repl, Path *p, Stat *s)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.s = s;
	rpc.p = p;
	rpc.type = Tremove;
	return dorpc(repl, &rpc);
}

/* Send a Tseek RPC moving fd to offset off; returns dorpc's result. */
int
rpcseek(Replica *repl, int fd, vlong off)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.vn = off;
	rpc.fd = fd;
	rpc.type = Tseek;
	return dorpc(repl, &rpc);
}

/*
 * Send a Tstat RPC for path p.
 * Returns the reply's Stat, or nil on failure.
 */
Stat*
rpcstat(Replica *repl, Path *p)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.p = p;
	rpc.type = Tstat;
	return dorpc(repl, &rpc) < 0 ? nil : rpc.s;
}

/*
 * Send a Twrite RPC carrying n bytes from a for fd.
 * Returns the reply's byte count, or -1 on failure.
 */
long
rpcwrite(Replica *repl, int fd, void *a, long n)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.n = n;
	rpc.a = a;
	rpc.fd = fd;
	rpc.type = Twrite;
	return dorpc(repl, &rpc) < 0 ? -1 : rpc.n;
}

/*
 * Send a Twritehash RPC carrying n bytes from a for fd.
 * Returns the reply's byte count, or -1 on failure.
 */
long
rpcwritehash(Replica *repl, int fd, void *a, long n)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.n = n;
	rpc.a = a;
	rpc.fd = fd;
	rpc.type = Twritehash;
	return dorpc(repl, &rpc) < 0 ? -1 : rpc.n;
}

/* Send a Twstat RPC for path p with stat s; returns dorpc's result. */
int
rpcwstat(Replica *repl, Path *p, Stat *s)
{
	Rpc rpc;

	memset(&rpc, 0, sizeof rpc);
	rpc.s = s;
	rpc.p = p;
	rpc.type = Twstat;
	return dorpc(repl, &rpc);
}

