#include <Python.h>

#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>

/*
 * recvmsg(fd, msglen, cmsglen, flags) -> (data, cmsglist, msg_flags)
 *
 * Thin wrapper around recvmsg(2) for a single-buffer message.
 *
 *   fd      - socket descriptor to read from.
 *   msglen  - maximum number of data bytes to receive (must be >= 0).
 *   cmsglen - size in bytes of the ancillary-data buffer (must be >= 0).
 *   flags   - passed unchanged to recvmsg(2).
 *
 * Returns a new 3-tuple (data, cmsglist, msg_flags):
 *   data      - str of the received bytes, never longer than msglen;
 *   cmsglist  - list of (cmsg_level, cmsg_type, payload) tuples, one per
 *               control message found in the received ancillary data;
 *   msg_flags - msg.msg_flags as reported by recvmsg(2).
 *
 * Raises ValueError if msglen or cmsglen is negative, OSError (from errno)
 * if recvmsg(2) fails, and MemoryError if a buffer cannot be allocated.
 * The GIL is released around the recvmsg(2) call.
 */
static PyObject *socket_recvmsg(PyObject *self, PyObject *args)
{
  int fd, msglen, cmsglen, flags;
  PyObject *msgbuf = NULL, *cmsglist = NULL, *flagsobj = NULL, *ret = NULL;
  void *cmsgbuf = NULL;
  struct iovec iov = { .iov_base = NULL, .iov_len = 0 };
  /* Designated initialisers: POSIX does not fix the member order (or
     padding members) of struct msghdr, so positional init is not portable. */
  struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1 };
  ssize_t outlen, datalen;

  if (!PyArg_ParseTuple(args, "iiii:recvmsg", &fd, &msglen, &cmsglen, &flags))
    return NULL;

  if (msglen < 0 || cmsglen < 0)
  {
    PyErr_SetString(PyExc_ValueError, "negative buffersize in recvmsg");
    return NULL;
  }

  msgbuf = PyString_FromStringAndSize(NULL, msglen);
  if (!msgbuf)
    goto error;

  /* The control buffer is only parsed here and never returned to Python, so
     a plain heap block suffices. Unlike the payload of a str object, memory
     from PyMem_Malloc is suitably aligned for struct cmsghdr, and it is not
     moved by a later resize while we are still walking it. */
  cmsgbuf = PyMem_Malloc(cmsglen);
  if (!cmsgbuf)
  {
    PyErr_NoMemory();
    goto error;
  }

  iov.iov_base = PyString_AS_STRING(msgbuf);
  iov.iov_len = msglen;
  msg.msg_control = cmsgbuf;
  msg.msg_controllen = cmsglen;

  Py_BEGIN_ALLOW_THREADS;
  outlen = recvmsg(fd, &msg, flags);
  Py_END_ALLOW_THREADS;

  if (outlen < 0)
  {
    PyErr_SetFromErrno(PyExc_OSError);
    goto error;
  }

  /* With MSG_TRUNC in flags, some socket types report the full datagram
     length, which can exceed msglen. Only msglen bytes were written, so never
     expose more than that; truncation is visible via msg_flags instead. */
  datalen = outlen < msglen ? outlen : msglen;
  if (datalen != msglen && _PyString_Resize(&msgbuf, datalen) < 0)
    goto error; /* on failure _PyString_Resize released msgbuf and NULLed it */

  cmsglist = PyList_New(0);
  if (!cmsglist)
    goto error;

  for (struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg))
  {
    unsigned char *data = (unsigned char *)CMSG_DATA(cmsg);
    size_t offset = (size_t)(data - (unsigned char *)msg.msg_control);
    size_t avail = offset < msg.msg_controllen ? msg.msg_controllen - offset : 0;
    size_t len = 0;

    /* Payload length from the header, clamped so that a short, truncated
       (MSG_CTRUNC) or malformed header never reads past the received data. */
    if ((size_t)cmsg->cmsg_len > CMSG_LEN(0))
      len = (size_t)cmsg->cmsg_len - CMSG_LEN(0);
    if (len > avail)
      len = avail;

    PyObject *item = Py_BuildValue("(iis#)", cmsg->cmsg_level, cmsg->cmsg_type,
                                   (const char *)data, (int)len);
    if (!item)
      goto error;
    int rc = PyList_Append(cmsglist, item);
    Py_DECREF(item);
    if (rc < 0)
      goto error;
  }

  flagsobj = PyInt_FromLong(msg.msg_flags);
  if (!flagsobj)
    goto error;

  ret = PyTuple_New(3);
  if (!ret)
    goto error;

  /* PyTuple_SET_ITEM steals these references. */
  PyTuple_SET_ITEM(ret, 0, msgbuf);
  PyTuple_SET_ITEM(ret, 1, cmsglist);
  PyTuple_SET_ITEM(ret, 2, flagsobj);
  PyMem_Free(cmsgbuf);

  return ret;

error:
  Py_XDECREF(msgbuf);
  Py_XDECREF(cmsglist);
  Py_XDECREF(flagsobj);
  PyMem_Free(cmsgbuf);
  return NULL;
}

/* Docstring for the module-level recvmsg(); the first line states the
   argument order and the shape of the returned tuple. */
PyDoc_STRVAR(recvmsg_doc,
"recvmsg(fd, buffersize, cmsglen, flags) -> (data, cmsglist, msg_flags)\n"
"\n"
"Receive a message from socket fd using a data buffer of buffersize bytes\n"
"and an ancillary-data buffer of cmsglen bytes. cmsglist is a list of\n"
"(level, type, data) tuples; msg_flags is the flags value set by recvmsg(2).");

/*
 * sendmsg(fd, data, cmsglist, flags) -> count
 *
 * Thin wrapper around sendmsg(2) for a single-buffer message.
 *
 *   fd       - socket descriptor to write to.
 *   data     - str whose bytes form the message body.
 *   cmsglist - list of (cmsg_level, cmsg_type, payload) tuples, packed in
 *              order into the ancillary-data buffer.
 *   flags    - passed unchanged to sendmsg(2).
 *
 * Returns the byte count reported by sendmsg(2) as an int.
 * Raises TypeError for a cmsglist entry that is not a tuple (or has the
 * wrong field types), MemoryError if the ancillary buffer cannot be grown,
 * and OSError (from errno) if sendmsg(2) fails. The GIL is released around
 * the sendmsg(2) call.
 */
static PyObject *socket_sendmsg(PyObject *self, PyObject *args)
{
  int fd, flags;
  char *msgbuf;
  int msglen;
  PyObject *cmsglist = NULL, *ret = NULL;
  unsigned char *cmsgbuf = NULL;
  size_t cmsglen = 0;
  struct iovec iov = { .iov_base = NULL, .iov_len = 0 };
  /* Designated initialisers: POSIX does not fix the member order (or
     padding members) of struct msghdr, so positional init is not portable. */
  struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1 };
  ssize_t outlen;

  if (!PyArg_ParseTuple(args, "is#O!i:sendmsg", &fd, &msgbuf, &msglen, &PyList_Type, &cmsglist, &flags))
    return NULL;

  iov.iov_base = msgbuf;
  iov.iov_len = msglen;

  for (Py_ssize_t i = 0; i < PyList_GET_SIZE(cmsglist); i++)
  {
    PyObject *item = PyList_GET_ITEM(cmsglist, i);
    int level, type, datalen;
    char *data;

    /* PyArg_ParseTuple on a non-tuple raises SystemError; report the
       caller's mistake as a TypeError instead. */
    if (!PyTuple_Check(item))
    {
      PyErr_SetString(PyExc_TypeError, "sendmsg: cmsglist items must be (level, type, data) tuples");
      goto done;
    }

    /* Own a reference while parsing and copying: the int conversions may run
       Python code that drops the item from the list, and `data` points into
       the item's string. */
    Py_INCREF(item);
    if (!PyArg_ParseTuple(item, "iis#:sendmsg", &level, &type, &data, &datalen))
    {
      Py_DECREF(item);
      goto done;
    }

    size_t space = CMSG_SPACE(datalen);
    if (space > (size_t)PY_SSIZE_T_MAX - cmsglen)
    {
      Py_DECREF(item);
      PyErr_NoMemory();
      goto done;
    }

    unsigned char *grown = PyMem_Realloc(cmsgbuf, cmsglen + space);
    if (!grown)
    {
      Py_DECREF(item);
      PyErr_NoMemory();
      goto done;
    }
    cmsgbuf = grown;

    /* Zero the whole slot so CMSG_SPACE alignment padding is not handed to
       the kernel as uninitialised heap bytes. */
    struct cmsghdr *c = (struct cmsghdr *)(cmsgbuf + cmsglen);
    memset(c, 0, space);
    c->cmsg_len = CMSG_LEN(datalen);
    c->cmsg_level = level;
    c->cmsg_type = type;
    memcpy(CMSG_DATA(c), data, datalen);
    Py_DECREF(item);

    cmsglen += space;
  }

  /* With an empty cmsglist this leaves msg_control NULL and length 0. */
  msg.msg_control = cmsgbuf;
  msg.msg_controllen = cmsglen;

  Py_BEGIN_ALLOW_THREADS;
  outlen = sendmsg(fd, &msg, flags);
  Py_END_ALLOW_THREADS;

  if (outlen < 0)
  {
    PyErr_SetFromErrno(PyExc_OSError);
    goto done;
  }

  ret = PyInt_FromLong(outlen);

done:
  /* Shared exit: success and failure paths both release the ancillary
     buffer; ret is NULL (with an exception set) on failure. */
  PyMem_Free(cmsgbuf);
  return ret;
}

/* Docstring for the module-level sendmsg(); the first line states the
   argument order and the return value. */
PyDoc_STRVAR(sendmsg_doc,
"sendmsg(fd, data, cmsglist, flags) -> count\n"
"\n"
"Send data on socket fd together with the ancillary data in cmsglist,\n"
"a list of (level, type, data) tuples. Returns the number of bytes sent.");

static PyMethodDef socket_methods[] =
{
  {
    "recvmsg", socket_recvmsg,
    METH_VARARGS, recvmsg_doc
  },
  {
    "sendmsg", socket_sendmsg,
    METH_VARARGS, sendmsg_doc
  },
};

/*
 * Module initialiser for `_socket`: registers socket_methods and exports the
 * socket-option / control-message constants as module attributes.
 *
 * On any failure it returns early with the Python exception still set, so the
 * import reports the first error instead of running further API calls with
 * an exception pending.
 */
PyMODINIT_FUNC init_socket(void)
{
  PyObject *m = Py_InitModule("_socket", socket_methods);
  if (!m)
    return;

  /* SO_PASSCRED and SCM_CREDENTIALS are not defined on every platform; guard
     them so the module still builds where the headers lack them. */
#ifdef SO_PASSCRED
  if (PyModule_AddIntConstant(m, "SO_PASSCRED", SO_PASSCRED) < 0)
    return;
#endif
  if (PyModule_AddIntConstant(m, "SCM_RIGHTS", SCM_RIGHTS) < 0)
    return;
#ifdef SCM_CREDENTIALS
  if (PyModule_AddIntConstant(m, "SCM_CREDENTIALS", SCM_CREDENTIALS) < 0)
    return;
#endif
#if 0
  /* NOTE(review): disabled in the original source; the reason is not visible
     here -- confirm before enabling (perhaps missing from some headers). */
  PyModule_AddIntConstant(m, "SCM_SECURITY", SCM_SECURITY);
#endif
}

// vim: sw=2 expandtab
