#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <ctype.h>
#include <dlfcn.h>
#include <fcntl.h>
#include <stdarg.h>
#include <assert.h>
#include <dirent.h>
#include <inttypes.h>
#include <pthread.h>

#include <sys/time.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/types.h>
#include <sys/signal.h>
#include <sys/syscall.h>

#include <netinet/in.h>
#include <netinet/tcp.h>

#include <xenctrl.h>

#include "xennerctrl.h"
#include "evtchnd.h"
#include "list.h"

/* ------------------------------------------------------------------ */

/*
 * When HAVE_EVTCHN_ERROR is not defined (presumably the xenctrl.h in use
 * does not provide evtchn_port_or_error_t -- TODO confirm), alias the
 * return type used below to the plain evtchn_port_t.
 */
#if !defined(HAVE_EVTCHN_ERROR)
# define evtchn_port_or_error_t evtchn_port_t
#endif

/* ------------------------------------------------------------------ */

/*
 * Per-handle state for one connection to the event channel daemon.
 * The handle handed out by xc_evtchn_open() is the connected socket
 * descriptor itself (evtchnd), which getpriv() uses as the lookup key.
 */
struct evtpriv {
    int                      evtchnd;   /* daemon socket; also the public handle */
    pthread_mutex_t          chnlock;   /* held around each exchange on evtchnd */
    struct list_head         list;      /* link in the privs list */
};

/* All open handles.  Every walk or modification is done under privlock. */
static LIST_HEAD(privs);
/* Statically initialized so it is valid before any handle is opened. */
static pthread_mutex_t privlock = PTHREAD_MUTEX_INITIALIZER;

/* ------------------------------------------------------------------ */
  
/*
 * Find the private state registered for handle `evtchnd`.
 *
 * The global privs list is scanned under privlock.  Returns the matching
 * entry, or NULL if no entry has that handle.  The lock is released
 * before returning, so the caller must ensure the entry is not freed
 * concurrently (e.g. by xc_evtchn_close()).
 */
static struct evtpriv *getpriv(int evtchnd)
{
    struct evtpriv *found = NULL;
    struct list_head *pos;

    pthread_mutex_lock(&privlock);
    list_for_each(pos, &privs) {
	struct evtpriv *cur = list_entry(pos, struct evtpriv, list);

	if (cur->evtchnd != evtchnd)
	    continue;
	found = cur;
	break;
    }
    pthread_mutex_unlock(&privlock);

    return found;
}

/*
 * Wait up to `waitsecs` seconds for `fd` to become readable.
 *
 * Returns the select() result: > 0 if fd is readable, 0 on timeout,
 * -1 on error (errno set by select).  If select() is interrupted by a
 * signal (EINTR), the wait is retried with only the time still remaining
 * rather than the full timeout again.  An interrupted wait whose time has
 * already run out reports a timeout (0).
 */
static int wait_data(int fd, int waitsecs)
{
    struct timeval start, now, tv;
    long long left_us;
    fd_set rd;
    int rc;

    gettimeofday(&start, NULL);
    tv.tv_sec  = waitsecs;
    tv.tv_usec = 0;
    for (;;) {
	FD_ZERO(&rd);
	FD_SET(fd, &rd);
	rc = select(fd + 1, &rd, NULL, NULL, &tv);
	if (rc >= 0 || errno != EINTR)
	    break;

	/* got a signal: retry, but only for the time that is left */
	gettimeofday(&now, NULL);
	left_us = (long long)waitsecs * 1000000
	    - ((long long)(now.tv_sec - start.tv_sec) * 1000000
	       + (now.tv_usec - start.tv_usec));
	if (left_us <= 0) {
	    rc = 0;  /* timeout */
	    break;
	}
	tv.tv_sec  = left_us / 1000000;
	tv.tv_usec = left_us % 1000000;
    }
    return rc;
}

/*
 * Send one ioctl-style request over the daemon socket and, unless
 * `no_reply` is set, wait for the reply carrying the same ioctl number.
 *
 * p        handle state from getpriv(); NULL (unknown handle) fails with
 *          errno = EBADF.
 * request  ioctl number; _IOC_SIZE(request) bytes of *data are sent, and
 *          on a matching reply the same number of bytes are copied back
 *          into *data.
 * no_reply when non-zero, return 0 as soon as the request is written.
 *
 * Returns the reply's retval (errno is set from the reply when retval is
 * -1), or -1 on a local failure: bad handle, oversized payload, write or
 * read error, or no reply within 5 seconds (errno = ETIMEDOUT).
 *
 * p->chnlock is held for the whole exchange and released on every path.
 */
static int evtchn_ioctl(struct evtpriv *p, int32_t request, void *data,
			int no_reply)
{
    struct evtchn_ioctl_msg req;
    struct evtchn_ioctl_msg rsp;
    int size = _IOC_SIZE(request);
    int rc, count = 0;

    if (NULL == p) {
	/* getpriv() found no entry for the caller's handle */
	errno = EBADF;
	return -1;
    }
    if ((size_t)size > sizeof(req.data)) {
	fprintf(stderr, "%s: data too big: %d/%zu\n",
		__FUNCTION__, size, sizeof(req.data));
	return -1;
    }

    pthread_mutex_lock(&p->chnlock);
    memset(&req, 0, sizeof(req));
    memcpy(req.data, data, size);
    req.ioctl = request;
    rc = write(p->evtchnd, &req, sizeof(req));
    if (rc != (int)sizeof(req)) {
	fprintf(stderr, "%s: write: %d/%zu (%s)\n", __FUNCTION__,
		rc, sizeof(req), strerror(errno));
	rc = -1;
	goto out;
    }
    if (no_reply) {
	rc = 0;
	goto out;
    }

retry:
    rc = wait_data(p->evtchnd, 5);
    if (0 == rc) {
	fprintf(stderr, "%s: no reply within 5 sec\n", __FUNCTION__);
	errno = ETIMEDOUT;
	rc = -1;
	goto out;
    }
    if (rc < 0) {
	/* select() itself failed; don't go on to read as if data arrived */
	fprintf(stderr, "%s: select: %s\n", __FUNCTION__, strerror(errno));
	rc = -1;
	goto out;
    }
    rc = read(p->evtchnd, &rsp, sizeof(rsp));
    if (rc != (int)sizeof(rsp)) {
	if (-1 == rc && (EAGAIN == errno || EINTR == errno) && count++ < 5) {
	    fprintf(stderr, "%s: read: retrying ... (%s)\n",
		    __FUNCTION__, strerror(errno));
	    goto retry;
	}
	fprintf(stderr, "%s: read: %d/%zu (%s)\n", __FUNCTION__,
		rc, sizeof(rsp),
		(-1 == rc) ? strerror(errno) : "partial read");
	rc = -1;
	goto out;
    }
    if (req.ioctl != rsp.ioctl) {
	if (rsp.ioctl == EVTCHND_NOTIFY) {
	    if (libxc_trace)
		fprintf(stderr, "%s: got async notify, push back\n", __FUNCTION__);
	    /*
	     * Write the notification back to the daemon socket, presumably
	     * so it is re-delivered to a later xc_evtchn_pending() read
	     * -- TODO confirm against the daemon.
	     */
	    if (write(p->evtchnd, &rsp, sizeof(rsp)) != (ssize_t)sizeof(rsp))
		fprintf(stderr, "%s: push back failed (%s)\n",
			__FUNCTION__, strerror(errno));
	} else if (libxc_fixme) {
	    fprintf(stderr, "%s: FIXME: unexpected reply\n", __FUNCTION__);
	}
	goto retry;
    }

    memcpy(data, rsp.data, size);
    if (-1 == rsp.retval)
	errno = rsp.error;
    rc = rsp.retval;

out:
    pthread_mutex_unlock(&p->chnlock);
    return rc;
}

/*
 * Open a TCP connection to the event channel daemon at
 * 127.0.0.1:EVTCHND_PORT with TCP_NODELAY set (Nagle disabled).
 *
 * Returns the connected socket, or -1 after printing the reason with
 * perror().  The socket is closed on failure, so nothing leaks.
 */
static int connect_tcp(void)
{
    struct sockaddr_in in;
    int sock, opt = 1;

    if (-1 == (sock = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP))) {
	perror("socket(tcp)");
	return -1;
    }
    /* presumably so small request/reply messages are not delayed by Nagle */
    setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt));

    memset(&in, 0, sizeof(in));   /* also clears sin_zero */
    in.sin_family      = AF_INET;
    in.sin_port        = htons(EVTCHND_PORT);
    in.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    if (-1 == connect(sock, (struct sockaddr*) &in, sizeof(in))) {
	perror("connect(tcp)");
	close(sock);
	return -1;
    }

    return sock;
}

/*
 * Open a connection to the event channel daemon's unix-domain socket
 * at EVTCHND_PATH.
 *
 * Returns the connected socket, or -1 after printing the reason with
 * perror().  The socket is closed on failure, so nothing leaks.
 */
static int connect_unix(void)
{
    struct sockaddr_un un;
    int sock;

    if (-1 == (sock = socket(PF_UNIX, SOCK_STREAM, 0))) {
	perror("socket(unix)");
	return -1;
    }

    memset(&un, 0, sizeof(un));
    un.sun_family = AF_UNIX;
    /* unlike strncpy, snprintf always NUL-terminates sun_path */
    snprintf(un.sun_path, sizeof(un.sun_path), "%s", EVTCHND_PATH);
    if (-1 == connect(sock, (struct sockaddr*) &un, sizeof(un))) {
	perror("connect(unix)");
	close(sock);
	return -1;
    }

    return sock;
}

/* ------------------------------------------------------------------ */

/*
 * Close an event channel handle and release its private state.
 *
 * Returns 0 on success, or -1 if `handle` is not registered.
 */
int xc_evtchn_close(int handle)
{
    struct evtpriv *p = getpriv(handle);

    if (!p)
	return -1;

    /*
     * Unlink before closing the descriptor.  Otherwise a concurrent
     * xc_evtchn_open() could be handed the same fd number and a lookup
     * could still find this stale entry.
     */
    pthread_mutex_lock(&privlock);
    list_del(&p->list);
    pthread_mutex_unlock(&privlock);

    if (p->evtchnd >= 0)   /* fd 0 is a valid descriptor too */
	close(p->evtchnd);
    pthread_mutex_destroy(&p->chnlock);
    free(p);
    return 0;
}

/*
 * Open a new event channel handle.
 *
 * Connects to the daemon, trying the unix socket first and falling back
 * to TCP on loopback.  The connected socket descriptor is the returned
 * handle.  The per-handle state is added to the privs list only once the
 * connection exists, so a lookup never sees a half-initialized entry.
 *
 * Returns the handle, or -1 on failure.
 */
int xc_evtchn_open(void)
{
    struct evtpriv *p;
    int fd;

    fd = connect_unix();
    if (-1 == fd)
	fd = connect_tcp();
    if (-1 == fd)
	goto err;

    p = calloc(1, sizeof(*p));
    if (NULL == p) {
	close(fd);
	goto err;
    }
    p->evtchnd = fd;
    pthread_mutex_init(&p->chnlock, NULL);

    pthread_mutex_lock(&privlock);
    list_add_tail(&p->list, &privs);
    pthread_mutex_unlock(&privlock);

    if (libxc_trace)
	fprintf(stderr, "libxc: %s -> handle %d\n",
		__FUNCTION__, p->evtchnd);
    return p->evtchnd;

err:
    if (libxc_trace)
	fprintf(stderr, "libxc: %s -> handle %d\n",
		__FUNCTION__, -1);
    return -1;
}

int xc_evtchn_fd(int handle)
{
    struct evtpriv *p = getpriv(handle);
    return p->evtchnd;
}

evtchn_port_or_error_t xc_evtchn_bind_unbound_port(int handle, int domid)
{
    struct evtpriv *p = getpriv(handle);
    struct ioctl_evtchn_bind_unbound_port io = {
	.remote_domain = domid,
    };
    int rc = evtchn_ioctl(p, IOCTL_EVTCHN_BIND_UNBOUND_PORT, &io, 0);

    if (libxc_trace)
	fprintf(stderr, "libxc: %s(handle %d, domid %d) -> %d\n",
		__FUNCTION__, handle, domid, rc);
    return rc;
}

/*
 * Issue IOCTL_EVTCHN_BIND_INTERDOMAIN for port `remote_port` of domain
 * `domid` through the daemon connection behind `handle`.  Returns
 * evtchn_ioctl()'s result (presumably the local port, or -1 on error).
 */
evtchn_port_or_error_t xc_evtchn_bind_interdomain(int handle, int domid,
					 evtchn_port_t remote_port)
{
    struct ioctl_evtchn_bind_interdomain req = {
	.remote_port   = remote_port,
	.remote_domain = domid,
    };
    struct evtpriv *priv = getpriv(handle);
    int port;

    port = evtchn_ioctl(priv, IOCTL_EVTCHN_BIND_INTERDOMAIN, &req, 0);
    if (libxc_trace) {
	fprintf(stderr, "libxc: %s(handle %d, domid %d, rport %d) -> %d\n",
		__FUNCTION__, handle, domid, remote_port, port);
    }
    return port;
}

/*
 * Issue IOCTL_EVTCHN_BIND_VIRQ for virtual IRQ `virq` through the daemon
 * connection behind `handle`.  Returns evtchn_ioctl()'s result
 * (presumably the bound port, or -1 on error).
 */
evtchn_port_or_error_t xc_evtchn_bind_virq(int handle, unsigned int virq)
{
    struct evtpriv *p = getpriv(handle);
    struct ioctl_evtchn_bind_virq io = {
	.virq = virq,
    };
    int rc = evtchn_ioctl(p, IOCTL_EVTCHN_BIND_VIRQ, &io, 0);

    /* virq is unsigned int, so it is printed with %u */
    if (libxc_trace)
	fprintf(stderr, "libxc: %s(handle %d, virq %u) -> %d\n",
		__FUNCTION__, handle, virq, rc);
    return rc;
}

/*
 * Issue IOCTL_EVTCHN_UNBIND for `port` through the daemon connection
 * behind `handle` and wait for the reply.  Returns evtchn_ioctl()'s
 * result.
 */
int xc_evtchn_unbind(int handle, evtchn_port_t port)
{
    struct ioctl_evtchn_unbind req = { .port = port };
    int ret = evtchn_ioctl(getpriv(handle), IOCTL_EVTCHN_UNBIND, &req, 0);

    if (libxc_trace) {
	fprintf(stderr, "libxc: %s(handle %d, port %d) -> %d\n",
		__FUNCTION__, handle, port, ret);
    }
    return ret;
}

/*
 * Send IOCTL_EVTCHN_NOTIFY for `port` through the daemon connection
 * behind `handle`.  No reply is awaited, so the result only reflects
 * whether the request could be written (0) or not (-1).
 */
int xc_evtchn_notify(int handle, evtchn_port_t port)
{
    struct ioctl_evtchn_notify req = { .port = port };
    struct evtpriv *priv = getpriv(handle);
    const int skip_reply = 1;
    int ret = evtchn_ioctl(priv, IOCTL_EVTCHN_NOTIFY, &req, skip_reply);

    if (libxc_trace) {
	fprintf(stderr, "libxc: %s(handle %d, port %d) -> %d\n",
		__FUNCTION__, handle, port, ret);
    }
    return ret;
}

/*
 * Read the next message from the daemon connection behind `handle`
 * (this blocks in read()) and return the port carried by an
 * EVTCHND_NOTIFY message.
 *
 * Returns the port, or -1 on any of:
 *   - an unknown handle (errno = EBADF);
 *   - a read error or short read;
 *   - a message that is not EVTCHND_NOTIFY.
 * EAGAIN/EINTR from read() are retried up to 5 times.
 */
evtchn_port_or_error_t xc_evtchn_pending(int handle)
{
    struct evtpriv *p = getpriv(handle);
    struct evtchn_ioctl_msg msg;
    struct evtchnd_port *n = (void*)(&msg.data);
    int rc, count = 0;

    if (NULL == p) {
	errno = EBADF;
	return -1;
    }

    pthread_mutex_lock(&p->chnlock);
    for (;;) {
	rc = read(p->evtchnd, &msg, sizeof(msg));
	if (rc == (int)sizeof(msg))
	    break;
	if (-1 == rc && (EAGAIN == errno || EINTR == errno) && count++ < 5)
	    continue;
	/* errno is only meaningful when read() itself failed */
	fprintf(stderr, "libxc: %s: read error (%d/%zu) %s\n",
		__FUNCTION__, rc, sizeof(msg),
		(-1 == rc) ? strerror(errno) : "partial read");
	pthread_mutex_unlock(&p->chnlock);
	return -1;
    }
    pthread_mutex_unlock(&p->chnlock);

    if (EVTCHND_NOTIFY != msg.ioctl) {
	fprintf(stderr, "libxc: %s: Huh? msg != notify?\n", __FUNCTION__);
	return -1;
    }
    if (libxc_trace)
	fprintf(stderr, "libxc: %s(handle %d) -> %d\n",
		__FUNCTION__, handle, n->port);
    return n->port;
}

/*
 * Unmask `port` on `handle`.
 *
 * Currently a no-op that always reports success: the EVTCHND_UNMASK
 * request to the daemon is compiled out below and kept for reference.
 */
int xc_evtchn_unmask(int handle, evtchn_port_t port)
{
    /* unused while the daemon request below is disabled */
    (void)handle;
    (void)port;

#if 0
    struct evtpriv *p = getpriv(handle);
    struct evtchn_ioctl_msg msg;
    struct evtchnd_port *n = (void*)(&msg.data);

    pthread_mutex_lock(&p->chnlock);
    memset(&msg, 0, sizeof(msg));
    n->port = port;
    msg.ioctl = EVTCHND_UNMASK;
    write(p->evtchnd, &msg, sizeof(msg));
    pthread_mutex_unlock(&p->chnlock);

    if (libxc_trace)
	fprintf(stderr, "libxc: %s(handle %d, port %d)\n",
		__FUNCTION__, handle, port);
#endif
    return 0;
}


/* ----------------------------------------------------------- */

int xenner_evtchnd_domid(int handle, int domid)
{
    struct evtpriv *p = getpriv(handle);
    struct evtchnd_domid io = {
	.domid = domid,
    };
    int rc = evtchn_ioctl(p, EVTCHND_DOMID, &io, 0);

    if (libxc_trace)
	fprintf(stderr, "libxc: %s(handle %d, domid %d) -> %d\n",
		__FUNCTION__, handle, domid, rc);
    return rc;
}
