/* $Id: avif.C,v 1.27 2005/10/20 07:03:29 dm Exp $ */

/*
 *
 * Copyright (C) 2003 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */


#include "asmtpd.h"
#include "rawnet.h"

/* All live avcount objects, indexed by uid.  An avcount inserts itself
 * here in its constructor and removes itself in its destructor;
 * avcount::get reuses an existing entry, so normally there is a single
 * counter per user. */
ihash<const uid_t, avcount, &avcount::uid, &avcount::link> avctab;

/* Create a counter for user U with no slots in use and register it in
 * avctab.  avcount::get calls this when no counter exists yet. */
avcount::avcount (uid_t u)
  : uid (u), num (0), release_lock (false)
{
  avctab.insert (this);
}

/* Unregister from avctab.  release() deletes the counter this way once
 * the last slot has been given back. */
avcount::~avcount ()
{
  avctab.remove (this);
}

/* Give back one slot previously obtained with acquire().
 *
 * After decrementing num, queued waiters are woken in FIFO order for as
 * long as the user is below opt->avenger_max_per_user.  A waiter callback
 * is expected to take a slot itself (via acquire()); if it does not, the
 * loop simply moves on to the next waiter.
 *
 * release_lock makes the wake-up loop non-reentrant: a waiter callback
 * that ends up calling release() again (for example, because the avif it
 * creates is destroyed immediately) only decrements num and returns,
 * leaving the outermost call to keep draining and to decide whether to
 * delete the counter.
 *
 * When no slots remain in use the object deletes itself, which also
 * removes it from avctab.  Callers must not touch it afterwards.
 *
 * NOTE(review): if avenger_max_per_user were 0, the loop would never run,
 * num would be 0, and any still-queued waiters would be discarded along
 * with the object.
 */
void
avcount::release ()
{
  assert (num > 0);
  num--;
  if (release_lock)
    return;			// an outer release() is already draining waiters
  release_lock = true;

  while (num < int (opt->avenger_max_per_user) && !waiters.empty ())
    (*waiters.pop_front ()) ();

  if (num)
    release_lock = false;
  else
    delete this;		// no slots in use: drop the per-user counter
}

/* Try to take one per-user slot.  Returns true and bumps the in-use
 * count when the user is below opt->avenger_max_per_user; returns false
 * (leaving the count untouched) otherwise. */
bool
avcount::acquire ()
{
  bool ok = num < int (opt->avenger_max_per_user);
  if (ok)
    num++;
  return ok;
}

/* Look up the counter for user U in avctab, creating and registering a
 * fresh one when none exists yet. */
avcount *
avcount::get (uid_t u)
{
  avcount *found = avctab[u];
  if (!found)
    found = New avcount (u);
  return found;
}

/* Wrap an already-spawned avenger child.
 *
 * S is the SMTP session the queries refer to.  N names the rule set in
 * log messages.  AVC is the per-user slot to release on destruction and
 * may be NULL.  P is the child's pid.  FD is the parent's end of the
 * socketpair connected to the child (see alloc).  C receives the final
 * verdict.  When opt->debug_avenger is set, the aio stream is given the
 * debug tag "NAME (aFD)". */
avif::avif (const smtpd *s, str n, avcount *avc, pid_t p, int fd, cb_t c)
  : cb (c),
    smtp (s),
    pid (p),
    aio (aios::alloc (fd)),
    name (n),
    avc (avc)
{
  if (!opt->debug_avenger)
    return;
  aio->setdebug (strbuf ("%s (a%d)", name.cstr (), fd));
}

/* Start servicing the child.
 *
 * Registers reap() to run when PID exits, then starts reading the
 * child's output lines into input().  Keep this order.  ~avif
 * unregisters the child callback with chldcb (pid, NULL).  If input()
 * were ever invoked and deleted this before chldcb had been registered,
 * the registration would be left pointing at a freed object.
 */
void
avif::init ()
{
  chldcb (pid, wrap (this, &avif::reap));
  aio->readline (wrap (this, &avif::input));
}

/* Tear down the connection to the avenger child.
 *
 * Cancels any pending read and aborts the aio stream.  If the child has
 * not been reaped yet (reap() sets pid to -1), its exit callback is
 * unregistered and it is sent SIGKILL: to its whole process group when
 * killpg succeeds, otherwise to the single pid.  The per-user slot in
 * avc, if any, is then released.  Finally every still-queued result is
 * freed with delres() (presumably also cancelling outstanding lookups
 * through their abortcb -- confirm in delres).
 */
avif::~avif ()
{
  aio->readcancel ();
  aio->abort ();
  if (pid > 0) {
    chldcb (pid, NULL);		// don't call reap() on a deleted object
    if (killpg (pid, SIGKILL) < 0)
      kill (pid, SIGKILL);	// not a group leader (or killpg failed)
  }
  if (avc)
    avc->release ();
  while (result *rp = reslist.first)
    delres (rp);
}

/* Consume one line written by the avenger child and act on it.
 *
 * This is the aio->readline callback.  If LINE is NULL (end of input or a
 * read error) or contains an embedded NUL byte, the caller is told NEXT
 * and this object is destroyed.  Otherwise LINE is matched against the
 * avenger command set:
 *
 *   (empty line)              ignored
 *   .TEXT                     ".TEXT" plus a newline is queued to be
 *                             echoed back to the child
 *   spf[1] VAR RECORD         evaluate SPF RECORD for this client; the
 *                             result comes back as "VAR=..."
 *   dns-a VAR NAME            A lookup      -> "VAR=addr addr ..."
 *   dns-ptr VAR ADDR          PTR lookup of dotted-quad ADDR
 *   dns-mx VAR NAME           MX lookup     -> "VAR=pref:host ..."
 *   dns-txt VAR NAME          TXT lookup
 *   netpath VAR HOST [HOPS]   network path (traceroute) to HOST
 *   return CODE[- ]TEXT       SMTP reply for the caller.  A '-' after the
 *                             code starts a multi-line reply whose later
 *                             lines are bare "CODE[- ]TEXT"
 *   redirect ADDR             hand ADDR to the caller as REDIR
 *   bodytest ARGS             hand ARGS to the caller as BODY
 *
 * Query results are queued in request order via newres() and written
 * back to the child by maybe_reply().  Any other line goes to badinput().
 * Every path that invokes cb also deletes this object and returns
 * immediately, so nothing may touch members after those points.
 */
void
avif::input (str line, int err)
{
  /* EOF, read error, or a line with an embedded NUL: stop talking to the
   * child and let the caller fall through to its default. */
  if (!line || strlen (line) != line.len ()) {
    (*cb) (NEXT, NULL);		// By default accept mail
    delete this;
    return;
  }

  static rxx spfrx ("^spf(1)?\\s+(\\w+)\\s+(.*)$");
  static rxx dnsarx ("^dns-a\\s+(\\w+)\\s+(\\S+)$");
  static rxx dnsptrrx ("^dns-ptr\\s+(\\w+)\\s+(\\S+)$");
  static rxx dnsmxrx ("^dns-mx\\s+(\\w+)\\s+(\\S+)$");
  static rxx dnstxtrx ("^dns-txt\\s+(\\w+)\\s+(\\S+)$");
  static rxx netpathrx ("^netpath\\s+(\\w+)\\s+(\\S+)(\\s+(-?\\d+))?$");
  static rxx retrx ("^return\\s+(([245]\\d\\d)([ -]).*)$");
  static rxx retcontrx ("^(([245]\\d\\d)([ -]).*)$");
  static rxx redirrx ("^redirect\\s+(\\S+)$");
  static rxx bodytestrx ("^bodytest\\s+(\\S.*)$");

  if (retcode) {
    /* Inside a multi-line "return" reply every line must be a bare
     * continuation carrying the same reply code. */
    if (!retcontrx.match (line) || retcode != retcontrx[2]) {
      badinput (line);
      return;
    }
    retbuf << retcontrx[1] << "\r\n";
    if (retcontrx[3] != "-") {
      // A space after the code marks the last line of the reply.
      (*cb) (DONE, retbuf);
      delete this;
      return;
    }
  }
  else if (!line.len ())
    ;				// blank lines are ignored
  else if (line[0] == '.')
    newres ()->res = line << "\n";
  else if (spfrx.match (line)) {
    /* Evaluate against MAIL FROM, falling back to the HELO name when the
     * sender is empty.  spfrx[1] (the optional "1") selects the spf1
     * output format in spf_cb. */
    str from = smtp->get_from ();
    if (!from || !from.len ())
      from = smtp->get_helo ();
    spf_t *spf = New spf_t (smtp->get_addr (), from);
    spf->spfrec = spfrx[3];
    spf->helo = smtp->get_helo ();
    spf->ptr_cache = smtp->ptr_cache;
    result *rp = newres ();
    spf->cb = wrap (this, &avif::spf_cb, spfrx[2], rp, spfrx[1]);
    rp->abortcb = wrap (spf_cancel, spf);
    spf->init ();
  }
  /* For the DNS queries below a NULL request handle is taken to mean no
   * cancellation hook is needed (presumably the callback already ran). */
  else if (dnsarx.match (line)) {
    result *rp = newres ();
    if (dnsreq *rqp
	= dns_hostbyname (dnsarx[2],
			  wrap (this, &avif::dns_a_cb, dnsarx[1], rp),
			  false, false))
      rp->abortcb = wrap (dnsreq_cancel, rqp);
  }
  else if (dnsptrrx.match (line)) {
    /* NOTE(review): this local `name` shadows the member avif::name. */
    str var = dnsptrrx[1];
    str name = dnsptrrx[2];
    result *rp = newres ();
    in_addr a;
    int r = inet_aton (name, &a);
    if (r == 0)
      rp->res = var << "=\n";	// not a dotted-quad: report VAR as empty
    else if (r < 0)
      rp->res = "";		// NOTE(review): inet_aton never returns < 0
    else if (dnsreq *rqp
	     = dns_hostbyaddr (a, wrap (this, &avif::dns_ptr_cb, var, rp)))
      rp->abortcb = wrap (dnsreq_cancel, rqp);
  }
  else if (dnsmxrx.match (line)) {
    result *rp = newres ();
    if (dnsreq *rqp
	= dns_mxbyname (dnsmxrx[2],
			wrap (this, &avif::dns_mx_cb, dnsmxrx[1], rp),
			false))
      rp->abortcb = wrap (dnsreq_cancel, rqp);
  }
  else if (dnstxtrx.match (line)) {
    result *rp = newres ();
    if (dnsreq *rqp
	= dns_txtbyname (dnstxtrx[2],
			wrap (this, &avif::dns_txt_cb, dnstxtrx[1], rp),
			false))
      rp->abortcb = wrap (dnsreq_cancel, rqp);
  }
  else if (netpathrx.match (line)) {
    /* Resolve HOST first; netpath_cb1 then starts the path probe.  HOPS
     * defaults to 0 when the optional count is absent. */
    result *rp = newres ();
    int hops = 0;
    if (str h = netpathrx[4])
      convertint (h, &hops);
    if (dnsreq *rqp = dns_hostbyname (netpathrx[2],
				     wrap (this, &avif::netpath_cb1,
					   netpathrx[1], hops, rp),
				     true, true))
      rp->abortcb = wrap (dnsreq_cancel, rqp);
  }
  else if (retrx.match (line)) {
    /* retcode is always unset here (the first branch handles the set
     * case), so the mismatch test below is purely defensive. */
    str code (retrx[2]);
    if (retcode && retcode != code) {
      badinput (line);
      return;
    }
    retcode = code;
    retbuf << retrx[1] << "\r\n";
    if (retrx[3] != "-") {
      (*cb) (DONE, retbuf);
      delete this;
      return;
    }
  }
  else if (redirrx.match (line)) {
    (*cb) (REDIR, redirrx[1]);
    delete this;
    return;
  }
  else if (bodytestrx.match (line)) {
    (*cb) (BODY, bodytestrx[1]);
    delete this;
    return;
  }
  else {
    badinput (line);
    return;
  }

  /* Still alive: wait for the next line, and send back any results that
   * are already complete (e.g. a '.' echo or an immediate answer). */
  aio->readline (wrap (this, &avif::input));
  maybe_reply ();
}

/* Return a printable copy of MSG for logging.
 *
 * Untrusted avenger output must not be able to smuggle control
 * characters into the logs.  DEL (0x7f) becomes "^?", each byte below
 * ' ' becomes caret notation ("^@" .. "^_"), and every other byte is
 * copied unchanged. */
str
safestring (str msg)
{
  strbuf out;
  const u_int n = msg.len ();
  for (u_int pos = 0; pos < n; pos++) {
    u_char ch = msg[pos];
    char esc[2];
    size_t esclen;
    if (ch == 0x7f) {
      esc[0] = '^';
      esc[1] = '?';
      esclen = 2;
    }
    else if (ch < ' ') {
      esc[0] = '^';
      esc[1] = ch + '@';
      esclen = 2;
    }
    else {
      esc[0] = ch;
      esclen = 1;
    }
    out.tosuio ()->copy (esc, esclen);
  }
  return out;
}
/* The child sent LINE, which is not a valid avenger command.  Log a
 * sanitized copy, tell the caller NEXT (fall through to the default),
 * and destroy this object.  Callers must return right after calling. */
void
avif::badinput (str line)
{
  str printable = safestring (line);
  warn ("bad input from %s's %s: ", name.cstr (), AVENGER)
    << printable << "\n";
  (*cb) (NEXT, NULL);
  delete this;
}

/* Completion callback for an "spf" / "spf1" request: report the verdict
 * as "VAR=RESULT\n".  ONE (set by the "spf1" form) renders spf->result
 * with spf1_print() instead of spf_print(). */
void
avif::spf_cb (str var, result *rp, bool one, spf_t *spf)
{
  strbuf reply;
  reply << var << "=";
  if (!one)
    reply << spf_print (spf->result);
  else
    reply << spf1_print (spf->result);
  reply << "\n";
  rp->res = reply;
  maybe_reply ();
}

/* Completion callback for a "dns-a VAR NAME" request.
 *
 * On success the result is "VAR=a.b.c.d [e.f.g.h ...]\n".  A permanent
 * failure yields "VAR=\n".  An answer that carries no addresses at all
 * also yields "VAR=\n" (previously such an answer dereferenced a NULL
 * h_addr_list[0]).  A temporary failure yields an empty result, so
 * nothing is sent back for VAR. */
void
avif::dns_a_cb (str var, result *rp, ptr<hostent> h, int err)
{
  strbuf sb;
  if (h && h->h_addr_list && h->h_addr_list[0]) {
    sb << var << "=";
    for (char **ap = h->h_addr_list; *ap; ap++) {
      if (ap != h->h_addr_list)
	sb << " ";
      // inet_ntoa returns a static buffer, so copy it into a str at once.
      sb << str (inet_ntoa (*(in_addr *) *ap));
    }
    sb << "\n";
  }
  else if (h || !dns_tmperr (err))
    sb << var << "=\n";

  rp->res = sb;
  maybe_reply ();
}

/* Completion callback for a "dns-ptr VAR ADDR" request.
 *
 * On success the result is "VAR=hostname [alias ...]\n".  A permanent
 * failure yields "VAR=\n".  A temporary failure yields an empty result,
 * so nothing is sent back for VAR. */
void
avif::dns_ptr_cb (str var, result *rp, ptr<hostent> h, int err)
{
  strbuf out;
  if (!h) {
    if (!dns_tmperr (err))
      out << var << "=\n";
  }
  else {
    out << var << "=" << h->h_name;
    char **alias = h->h_aliases;
    while (*alias) {
      out << " " << *alias;
      alias++;
    }
    out << "\n";
  }
  rp->res = out;
  maybe_reply ();
}

/* Completion callback for a "dns-mx VAR NAME" request.
 *
 * On success the result is "VAR=pref:host [pref:host ...]\n", in the
 * order the resolver returned the records.  A permanent failure yields
 * "VAR=\n".  An answer with zero MX records also yields "VAR=\n"
 * (previously m_mxes[0] was read even when m_nmx was 0).  A temporary
 * failure yields an empty result, so nothing is sent back for VAR. */
void
avif::dns_mx_cb (str var, result *rp, ptr<mxlist> mxl, int err)
{
  strbuf sb;
  if (mxl && mxl->m_nmx > 0) {
    sb << var << "=";
    for (u_int i = 0; i < mxl->m_nmx; i++) {
      if (i)
	sb << " ";
      sb << int (mxl->m_mxes[i].pref) << ":" << mxl->m_mxes[i].name;
    }
    sb << "\n";
  }
  else if (mxl || !dns_tmperr (err))
    sb << var << "=\n";

  rp->res = sb;
  maybe_reply ();
}

/* Completion callback for a "dns-txt VAR NAME" request.
 *
 * On success the result is "VAR=<first TXT record>\n".  A permanent
 * failure yields "VAR=\n".  A temporary failure yields an empty result,
 * so nothing is sent back for VAR. */
void
avif::dns_txt_cb (str var, result *rp, ptr<txtlist> t, int err)
{
  strbuf out;
  if (!t) {
    if (!dns_tmperr (err))
      out << var << "=\n";
  }
  else {
    /* XXX - Only the first TXT record is reported.  If that record
     * contains a newline, maybe_reply will refuse to send it, so the
     * variable will not get set. */
    out << var << "=" << t->t_txts[0] << "\n";
  }
  rp->res = out;
  maybe_reply ();
}

/* First stage of a "netpath VAR HOST [HOPS]" request: HOST has been
 * resolved.
 *
 * The DNS request is finished, so its abort hook is cleared.  If HOST
 * resolved to at least one address, a path probe to the first address
 * is started with netpath_cb2 as its completion.  Otherwise an empty
 * result is recorded, so nothing is sent back for VAR.  The sockaddr is
 * fully zeroed first so that padding (sin_zero) and, on BSD-derived
 * systems, sin_len are not left uninitialized. */
void
avif::netpath_cb1 (str var, int hops, result *rp, ptr<hostent> h, int err)
{
  rp->abortcb = NULL;		// DNS request done; nothing left to cancel
  if (h && h->h_addr_list && h->h_addr_list[0]) {
    sockaddr_in sin;
    memset (&sin, 0, sizeof (sin));
    sin.sin_family = AF_INET;
    sin.sin_port = htons (0);
    sin.sin_addr = *(in_addr *) h->h_addr_list[0];
    if (traceroute *trp = netpath (&sin, hops,
				   wrap (this, &avif::netpath_cb2, var, rp)))
      rp->abortcb = wrap (netpath_cancel, trp);
  }
  else {
    rp->res = "";
    maybe_reply ();
  }
}
/* Second stage of a "netpath" request: the path probe finished.
 *
 * When AN > 0 the result is "VAR=NHOPS addr1 addr2 ...\n", listing the AN
 * addresses in AV in order.  Otherwise the result is empty, so nothing is
 * sent back for VAR. */
void
avif::netpath_cb2 (str var, result *rp, int nhops, in_addr *av, int an)
{
  strbuf out;
  if (an > 0) {
    out << var << "=" << nhops;
    for (in_addr *hop = av; hop < av + an; hop++)
      out << " " << inet_ntoa (*hop);
    out << "\n";
  }
  rp->res = out;
  maybe_reply ();
}

/* Send completed results back to the child, strictly in request order.
 *
 * Walks reslist from the front and stops at the first entry whose answer
 * has not arrived yet (res is NULL), so replies are never reordered.  A
 * result is written to the child only if it is exactly one line: its only
 * newline is the final byte and it contains no NUL.  Empty results (for
 * example after a temporary DNS failure) are dropped silently.  Any other
 * malformed result is logged and dropped.  Every handled entry is removed
 * with delres().
 */
void
avif::maybe_reply ()
{
  result *rp;
  while ((rp = reslist.first) && rp->res) {
    const char *p = rp->res.cstr ();
    size_t n = rp->res.len ();
    /* Sheer paranoia--what if some weird characters come back in DNS
     * requests (or something) and somehow don't get filtered by the
     * resolver.  Test the length first so that an empty result never
     * forms a pointer before the start of its buffer. */
    if (n > 0 && memchr (p, '\n', n) == p + n - 1 && !memchr (p, '\0', n))
      aio << rp->res;
    else if (n)
      warn << "user " << name << " newline should be at end of variable\n";
    delres (rp);
  }
}

/* Prepare the child process before avenger starts.
 *
 * alloc() passes this to aspawn as the setup hook.  It appears to run in
 * the forked child (it calls _exit rather than returning errors).
 *
 * PW is the user whose rules run.  FD and EXT are currently unused.  SYS
 * selects the system-wide rule directory PW->pw_dir; otherwise the
 * per-user directory PW->pw_dir/.avenger is used.  Steps:
 *   - arm an alarm for opt->avenger_timeout seconds, if configured;
 *   - for per-user rules (when seteuid exists), lstat the directory as
 *     the user and require that it is a directory owned by the user or by
 *     root; a missing directory makes the child exit silently, and a
 *     temporary error emits "return 451 ..." first;
 *   - set supplementary groups (system rules only) and switch to PW;
 *   - chdir into the directory.  A temporary failure emits
 *     "return 451 ..." and exits.  Any other failure is fatal only for
 *     per-user rules.
 * errno is saved right after each failing call.  maybe_warn() and the
 * stream output can overwrite it before tmperr()/strerror() look at it.
 */
void
avif::chldinit (struct passwd *pw, int fd, bool sys, str ext)
{
  if (opt->avenger_timeout)
    alarm (opt->avenger_timeout);

  bool root = getuid () <= 0;
  str avdir;
  if (sys)
    avdir = pw->pw_dir;
  else
    avdir = strbuf () << pw->pw_dir << "/.avenger";

#ifdef HAVE_SETEUID
  if (!sys) {
    /* quick optimization because setgroups is expensive */
    GETGROUPS_T gid = pw->pw_gid;
    setgid (gid);
    if (root)
      seteuid (pw->pw_uid);	// look at the directory with the user's rights
    struct stat sb;
    if (lstat (avdir, &sb)) {
      int err = errno;
      if (smtpd::tmperr (err)) {
	aout << "return 451 " << avdir << ": " << strerror (err) << "\n";
	aout->flush ();
      }
      _exit (0);
    }
    if (root)
      seteuid (getuid ());
    /* Must be a real directory (lstat: no symlinks), owned by the user
     * or by root. */
    if (!S_ISDIR (sb.st_mode) || (sb.st_uid && sb.st_uid != pw->pw_uid)) {
      warn << avdir << " should be directory owned by " << pw->pw_name << "\n";
      _exit (0);
    }
  }
#endif /* HAVE_SETEUID */

  if (sys)
    setgroups (opt->av_groups.size (), opt->av_groups.base ());
  become_user (pw, !sys);

  if (chdir (avdir) < 0) {
    int err = errno;		// maybe_warn may clobber errno
    maybe_warn (strbuf ("%s: %m\n", avdir.cstr ()));
    if (smtpd::tmperr (err)) {
      aout << "return 451 " << avdir << ": " << strerror (err) << "\n";
      aout->flush ();
      _exit (0);
    }
    if (!sys)
      _exit (0);
  }
}

void
avif::alloc (struct passwd *pw, const smtpd *s, str recip, char mode,
	     avcount *avc, str ext, str avuser, cb_t cb, str extraenv)
{
  if (mode == 's') {
    str path = strbuf () << opt->etcdir << "/" << ext;
    if (access (path, 0) < 0 && errno == ENOENT) {
      (*cb) (NEXT, NULL);
      return;
    }
  }

  int fds[2];
  if (socketpair (AF_UNIX, SOCK_STREAM, 0, fds) < 0) {
    (*cb) (DONE, strbuf ("451 %m\r\n"));
    return;
  }

  close_on_exec (fds[0]);
  if (fds[1] > 1)
    close_on_exec (fds[1]);

  const char *av[] = { path_avenger, NULL, NULL, NULL };
  str modestr;
  if (mode) {
    av[1] = modestr = strbuf ("-%c", mode);
    av[2] = ext;
  }
  else
    av[1] = ext;

  vec<str> senv;
  s->envinit (&senv, pw);

  if (!strncmp (senv[0], "PWD=", 4))
    senv.pop_front ();
  if (mode == 's')
    senv.push_back (strbuf ("PWD=%s", pw->pw_dir));
  else
    senv.push_back (strbuf ("PWD=%s/.avenger", pw->pw_dir));
  senv.push_back (strbuf ("RECIPIENT=") << recip);
  senv.push_back (strbuf ("RECIPIENT_HOST=")
		  << mytolower (extract_domain (recip)));
  senv.push_back (strbuf ("RECIPIENT_LOCAL=")
		  << mytolower (extract_local (recip)));
  if (ext && mode != 's')
    senv.push_back (strbuf ("EXT=") << ext);
  if (extraenv)
    senv.push_back (extraenv);
  if (avuser)
    senv.push_back (strbuf () << "AVUSER=" << avuser);

  vec<const char *> env;
  env.reserve (senv.size () + 1);
  for (const str *sp = senv.base (); sp < senv.lim (); sp++)
    env.push_back (sp->cstr ());
  env.push_back (NULL);

  pid_t pid = aspawn (path_avenger, av, fds[1], fds[1], errfd,
		      wrap (&chldinit, pw, fds[1], mode == 's', ext),
		      const_cast<char **> (env.base ()));
  if (pid <= 0) {
    (*cb) (DONE, strbuf ("451 %m\r\n"));
    return;
  }

  str name;
  if (mode == 's' && ext)
    name = strbuf ("%c", opt->separator ? opt->separator : ' ') << ext;
  else if (ext)
    name = strbuf ("%s%c", pw->pw_name, opt->separator) << ext;
  else
    name = pw->pw_name;

  close (fds[1]);
  (New avif (s, name, avc, pid, fds[0], cb))->init ();
}

/* Child-exit callback registered by init().  Marks the child as gone
 * (pid = -1, so ~avif will not signal it) and logs a warning when it
 * exited with a non-zero status. */
void
avif::reap (int status)
{
  pid = -1;
  if (!status)
    return;
  warn << AVENGER " for " << name << " exited with "
       << exitstr (status) << "\n";
}
