/*
   Copyright (C) 2000-2002  Ulric Eriksson <ulric@siag.nu>

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place - Suite 330, Boston,
   MA 02111-1307, USA.
*/

#include "config.h"

#include <stdio.h>
#include <ctype.h>
#include <errno.h>
#include <netdb.h>
#include <sys/types.h>
#include <assert.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <signal.h>
#include <stdlib.h>
#include <stdarg.h>
#include <ctype.h>
#include <time.h>
#ifdef TIME_WITH_SYS_TIME
#include <sys/time.h>
#endif
#include <sys/wait.h>
#include <sys/resource.h>
#include <sys/socket.h>
#include <sys/stat.h>
#ifdef HAVE_POLL
#include <sys/poll.h>
#endif
#include <fcntl.h>
#include <unistd.h>
#include <syslog.h>
#include <string.h>
#include <pwd.h>

/* Size of the stack buffer used by each read() in copy_up()/copy_down(). */
#define BUFFER_MAX 	(32*1024)

/* Compiled-in defaults for the runtime-tunable settings (see usage()). */
#define CLIENTS_MAX	2048	/* max clients tracked */
#define CONNECTIONS_MAX	256	/* max simultaneous connections */
#define TIMEOUT		5	/* default connect timeout (seconds) for non reachable hosts */
#define BLACKLIST_TIME	30	/* how long (seconds) to shun a server that is down */
#define KEEP_MAX	100	/* max bytes of each request written by log() */

/* One proxied connection: downfd is the accepted client socket and upfd
   the socket to the chosen server. A slot is unused when its fds are -1. */
typedef struct {
	int downfd, upfd;
	/* Data that could not be written yet: downb holds bytes waiting to
	   go out on downfd, upb bytes waiting for upfd; the *bptr fields
	   point at the next unwritten byte within each buffer. */
	unsigned char *downb, *downbptr, *upb, *upbptr;
	int downn, upn;		/* bytes still pending in downb / upb */
	int clt;		/* index into clients[] */
	int index;		/* server index */
} connection;

/* Per-backend state: one entry per server named on the command line,
   plus the emergency server if one was given. */
typedef struct {
	int status;		/* last failed connection attempt: time() of the
				   last failed connect, or 0 after a success */
	int port;
	struct in_addr addr;
	int c;			/* connections currently open to this server */
	int maxc;		/* max connections, soft limit (0 = unlimited) */
	int hard;		/* max connections, hard limit */
	unsigned long long sx, rx;	/* bytes sent, received */
} server;

/* Per-client tracking entry, used to send a returning client back to the
   server it used before. A slot with last == 0 is unused. */
typedef struct {
	time_t last;		/* last time this client made a connection */
	struct in_addr addr;	/* of client */
	int cno;		/* server used last time */
	long connects;		/* number of connections made */
	long long csx, crx;	/* bytes sent to / received from the client */
} client;

static int nservers;		/* number of servers */
static int current;		/* current server */
static client *clients;
static server *servers;
static int emerg_server = -1;	/* server of last resort */
static connection *conns;

static int debuglevel;
static int asciidump;
static int foreground;
/* Flags written by signal handlers (quit, stats, restart_log) must be
   volatile sig_atomic_t for the main loop to observe them reliably. */
static volatile sig_atomic_t loopflag;

static int clients_max = CLIENTS_MAX;
static int connections_max = CONNECTIONS_MAX;
static int timeout = TIMEOUT;
static int blacklist_time = BLACKLIST_TIME;
static int roundrobin = 0;
static int hash = 0;
static int stubborn = 0;
static int nblock = 1;
static int delayed_forward = 0;
static volatile sig_atomic_t do_stats = 0;
static volatile sig_atomic_t do_restart_log = 0;
static int use_poll = 0;

static int port;

static char *logfile = NULL;
static FILE *logfp = NULL;
static struct sockaddr_in logserver;
static int logsock = -1;
static char *pidfile = NULL;
static FILE *pidfp = NULL;
static char *webfile = NULL;
static char *listenport = NULL;
static char *ctrlport = NULL;
static char *e_server = NULL;
static char *jail = NULL;
static char *user = NULL;

/* printf-style diagnostic output. In the foreground the message goes to
   stderr prefixed with a local timestamp; otherwise it is sent to syslog
   at LOG_DEBUG under the identity "pen". Messages longer than 4095 bytes
   are truncated. */
static void debug(char *fmt, ...)
{
	char msg[4096];
	char stamp[80];
	va_list args;
	time_t t;

	va_start(args, fmt);
	vsnprintf(msg, sizeof msg, fmt, args);

	/* The timestamp is only included in the stderr output. */
	t = time(NULL);
	strftime(stamp, sizeof stamp, "%Y-%m-%d %H:%M:%S", localtime(&t));

	if (!foreground) {
		openlog("pen", LOG_CONS, LOG_USER);
		syslog(LOG_DEBUG, "%s\n", msg);
		closelog();
	} else {
		fprintf(stderr, "%s: %s\n", stamp, msg);
	}
	va_end(args);
}

/* Report a fatal printf-style error (stderr in the foreground, syslog at
   LOG_ERR otherwise) and terminate the process with exit status 1. */
static void error(char *fmt, ...)
{
	va_list args;
	char msg[4096];

	va_start(args, fmt);
	vsnprintf(msg, sizeof msg, fmt, args);
	if (!foreground) {
		openlog("pen", LOG_CONS, LOG_USER);
		syslog(LOG_ERR, "%s\n", msg);
		closelog();
	} else {
		fprintf(stderr, "%s\n", msg);
	}
	va_end(args);
	exit(1);
}

/* malloc() that never returns NULL: exits via error() on failure. */
static void *pen_malloc(size_t n)
{
	void *mem;

	mem = malloc(n);
	if (mem == NULL)
		error("Can't malloc %ld bytes", (long)n);
	return mem;
}

/* calloc() that never returns NULL: returns zeroed storage for n objects
   of s bytes each, or exits via error() on failure. */
static void *pen_calloc(size_t n, size_t s)
{
	void *q = calloc(n, s);
	/* n*s is size_t (unsigned), so print it with a matching unsigned
	   conversion rather than %ld. */
	if (!q) error("Can't calloc %lu bytes", (unsigned long)n * (unsigned long)s);
	return q;
}

/* Return a heap copy of the NUL-terminated string p (allocated with
   pen_malloc(), so it never returns NULL). The caller owns the copy. */
static char *pen_strdup(char *p)
{
	size_t len = strlen(p) + 1;	/* include the terminator */
	char *copy = pen_malloc(len);

	memcpy(copy, p, len);
	return copy;
}

/* Write a snapshot of pen's state as an HTML page to the file named by
   webfile, truncating it ("w"). Returns silently if the file cannot be
   opened. The page contains three tables: every server, the client slots
   in use (last != 0), and the connection slots in use (downfd != -1). */
static void webstats(void)
{
	FILE *fp;
	int i;
	time_t now;
	struct tm *nowtm;
	char nowstr[80];

	fp = fopen(webfile, "w");
	if (fp == NULL) return;
	now=time(NULL);
	nowtm = localtime(&now);
	strftime(nowstr, sizeof(nowstr), "%Y-%m-%d %H:%M:%S", nowtm);
	fprintf(fp,
		"<html>\n"
		"<head>\n"
		"<title>Pen status page</title>\n"
		"</head>\n"
		"<body bgcolor=\"#ffffff\">"
		"<h1>Pen status page</h1>\n");
	fprintf(fp,
		"Time %s, %d servers, %d current<p>\n",
		nowstr, nservers, current);
	/* Server table. "status" is 0, or the time() of the last failed
	   connect as recorded by try_server(). */
	fprintf(fp,
		"<table bgcolor=\"#c0c0c0\">\n"
		"<tr>\n"
		"<td bgcolor=\"#80f080\">server</td>\n"
		"<td bgcolor=\"#80f080\">address</td>\n"
		"<td bgcolor=\"#80f080\">status</td>\n"
		"<td bgcolor=\"#80f080\">port</td>\n"
		"<td bgcolor=\"#80f080\">connections</td>\n"
		"<td bgcolor=\"#80f080\">max soft</td>\n"
		"<td bgcolor=\"#80f080\">max hard</td>\n"
		"<td bgcolor=\"#80f080\">sent</td>\n"
		"<td bgcolor=\"#80f080\">received</td>\n"
		"</tr>\n");
	for (i = 0; i < nservers; i++) {
		fprintf(fp,
			"<tr>\n"
			"<td>%d</td>\n"
			"<td>%s</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%llu</td>\n"
			"<td>%llu</td>\n"
			"</tr>\n",
			i, inet_ntoa(servers[i].addr),
			servers[i].status, servers[i].port,
			servers[i].c, servers[i].maxc, servers[i].hard,
			servers[i].sx, servers[i].rx);
	}
	fprintf(fp, "</table>\n");

	/* Client table: only slots that have been used (last != 0).
	   "age" is seconds since the client's most recent connection. */
	fprintf(fp, "<h2>Active clients</h2>");
	fprintf(fp, "Max number of clients: %d<p>", clients_max);
	fprintf(fp,
		"<table bgcolor=\"#c0c0c0\">\n"
		"<tr>\n"
		"<td bgcolor=\"#80f080\">client</td>\n"
		"<td bgcolor=\"#80f080\">address</td>\n"
		"<td bgcolor=\"#80f080\">age(secs)</td>\n"
		"<td bgcolor=\"#80f080\">last server</td>\n"
		"<td bgcolor=\"#80f080\">connects</td>\n"
		"<td bgcolor=\"#80f080\">sent</td>\n"
		"<td bgcolor=\"#80f080\">received</td>\n"
		"</tr>\n");
	for (i = 0; i < clients_max; i++) {
		if (clients[i].last == 0) continue;
		fprintf(fp,
			"<tr>\n"
			"<td>%d</td>\n"
			"<td>%s</td>\n"
			"<td>%ld</td>\n"
			"<td>%d</td>\n"
			"<td>%ld</td>\n"
			"<td>%lld</td>\n"
			"<td>%lld</td>\n"
			"</tr>\n",
			i, inet_ntoa(clients[i].addr),
			(long)(now-clients[i].last), clients[i].cno, clients[i].connects,
			clients[i].csx, clients[i].crx);
	}
	fprintf(fp, "</table>\n");

	/* Connection table: only occupied slots (downfd != -1), with the
	   number of bytes still buffered in each direction. */
	fprintf(fp, "<h2>Active connections</h2>");
	fprintf(fp, "Max number of connections: %d<p>", connections_max);
	fprintf(fp,
		"<table bgcolor=\"#c0c0c0\">\n"
		"<tr>\n"
		"<td bgcolor=\"#80f080\">connection</td>\n"
		"<td bgcolor=\"#80f080\">downfd</td>\n"
		"<td bgcolor=\"#80f080\">upfd</td>\n"
		"<td bgcolor=\"#80f080\">pending data down</td>\n"
		"<td bgcolor=\"#80f080\">pending data up</td>\n"
		"<td bgcolor=\"#80f080\">client</td>\n"
		"<td bgcolor=\"#80f080\">server</td>\n"
		"</tr>\n");
	for (i = 0; i < connections_max; i++) {
		if (conns[i].downfd == -1) continue;
		fprintf(fp,
			"<tr>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"<td>%d</td>\n"
			"</tr>\n",
			i, conns[i].downfd, conns[i].upfd,
			conns[i].downn, conns[i].upn,
			conns[i].clt, conns[i].index);
	}
	fprintf(fp, "</table>\n");
	fprintf(fp,
		"</body>\n"
		"</html>\n");
	fclose(fp);
}

/* Report the same information as webstats() through debug(): a summary
   line, then one entry per server, per client slot in use (last != 0)
   and per connection slot in use (downfd != -1). */
static void textstats(void)
{
	int i;
	time_t now;
	struct tm *nowtm;
	char nowstr[80];

	now=time(NULL);
	nowtm = localtime(&now);
	strftime(nowstr, sizeof(nowstr), "%Y-%m-%d %H:%M:%S", nowtm);

	debug("Time %s, %d servers, %d current",
		nowstr, nservers, current);
	for (i = 0; i < nservers; i++) {
		debug("Server %d status:\n"
			"address %s\n"
			"%d\n"
			"port %d\n"
			"%d connections (%d soft, %d hard)\n"
			"%llu sent, %llu received\n",
			i, inet_ntoa(servers[i].addr),
			servers[i].status, servers[i].port,
			servers[i].c, servers[i].maxc, servers[i].hard,
			servers[i].sx, servers[i].rx);
	}
	debug("Max number of clients: %d", clients_max);
	debug("Active clients:");
	for (i = 0; i < clients_max; i++) {
		if (clients[i].last == 0) continue;
		/* The format pieces must be adjacent string literals so they
		   form ONE format argument; csx/crx are signed long long. */
		debug("Client %d status:\n"
			"address %s\n"
			"last used %ld\n"
			"last server %d\n"
			"connects  %ld\n"
			"sent  %lld\n"
			"received  %lld\n",
			i, inet_ntoa(clients[i].addr),
			(long)(now-clients[i].last), clients[i].cno, clients[i].connects,
			clients[i].csx, clients[i].crx);
	}
	debug("Max number of connections: %d", connections_max);
	debug("Active connections:");
	for (i = 0; i < connections_max; i++) {
		if (conns[i].downfd == -1) continue;
		debug("Connection %d status:\n"
			"downfd = %d, upfd = %d\n"
			"pending data %d down, %d up\n"
			"client %d, server %d\n",
			i, conns[i].downfd, conns[i].upfd,
			conns[i].downn, conns[i].upn,
			conns[i].clt, conns[i].index);
	}
}


/* SIGUSR1 handler: request a statistics dump. The main loop sees
   do_stats and performs the dump outside signal context. */
static void stats(int dummy)
{
	(void)dummy;
	do_stats = 1;
	/* Re-install in case signal() has one-shot semantics here. */
	signal(SIGUSR1, stats);
}

/* SIGHUP handler: request that the log file be reopened. The main loop
   sees do_restart_log and reopens it outside signal context. */
static void restart_log(int dummy)
{
	(void)dummy;
	do_restart_log = 1;
	/* Re-install in case signal() has one-shot semantics here. */
	signal(SIGHUP, restart_log);
}

/* SIGTERM handler: clear loopflag so the main loop exits. */
static void quit(int dummy)
{
	(void)dummy;
	loopflag = 0;
}

/* Return the index of the first clients[] slot whose address equals cli,
   or -1 if there is none.
   NOTE(review): unused slots have address 0, so a client whose address is
   0.0.0.0 would match an empty slot. */
static int lookup_client(struct in_addr cli)
{
	int idx = -1;
	int k;
	unsigned long wanted = cli.s_addr;

	for (k = 0; idx == -1 && k < clients_max; k++) {
		if (clients[k].addr.s_addr == wanted)
			idx = k;
	}
	if (debuglevel) debug("Client %s has index %d", inet_ntoa(cli), idx);
	return idx;
}

/* Store client and return index.
   clino is the client's known slot, or -1 when unknown or when lookup was
   bypassed. If clino is -1, the slot is found by address; failing that,
   the first free slot (last == 0) is used, or else the least recently used
   one is recycled. Only a newly assigned slot has its counters reset.
   Records ch as the client's server and increments servers[ch].c. */
static int store_client(int clino, struct in_addr cli, int ch)
{
	int i;
	int empty = -1;		/* first empty slot */
	int oldest = -1;	/* in case we need to recycle */

	if (clino == -1) {
		for (i = 0; i < clients_max; i++) {
			/* Already tracked (e.g. roundrobin skipped lookup_client) */
			if (clients[i].addr.s_addr == cli.s_addr) break;
			if (empty != -1) continue;
			if (clients[i].last == 0) {
				empty = i;
				continue;
			}
			if (oldest == -1 || (clients[i].last < clients[oldest].last)) {
				oldest = i;
			}
		}

		if (i == clients_max) {
			/* New client: take a free or recycled slot and start its
			   statistics from zero. A slot that already belongs to
			   this address keeps its running totals. */
			if (empty != -1) i = empty;
			else i = oldest;
			clients[i].connects = 0;
			clients[i].csx = 0;
			clients[i].crx = 0;
		}
	} else {
		i = clino;
	}

	clients[i].last = time(NULL);
	clients[i].addr = cli;
	clients[i].cno = ch;
	clients[i].connects++;

	if (debuglevel) {
		debug("Client %s has index %d and server %d",
			inet_ntoa(cli), i, ch);
	}
	servers[ch].c++;

	return i;
}

/* Debug dump of n bytes at p to stderr, prefixed by the byte count.
   With asciidump set, printable and whitespace bytes are shown as-is and
   everything else as '.'; otherwise each byte is shown as two hex digits. */
static void dump(unsigned char *p, int n)
{
	int k;

	fprintf(stderr, "%d: ", n);
	for (k = 0; k < n; k++) {
		unsigned char c = p[k];

		if (!asciidump)
			fprintf(stderr, "%02x ", (int)c);
		else
			fprintf(stderr, "%c", (isprint(c) || isspace(c)) ? c : '.');
	}
	fputc('\n', stderr);
}


/* Translate a TCP service name ("smtp") or a decimal string ("25") into a
   port number in host byte order. Unparseable input yields 0 (atoi). */
static int getport(char *p)
{
	struct servent *svc;

	svc = getservbyname(p, "tcp");
	if (svc != NULL)
		return ntohs(svc->s_port);
	return atoi(p);
}

/* Resolve host name or dotted-quad p into *a. Exits via error() if it
   cannot be resolved. Only an IPv4 resolver result is copied, so the copy
   can never exceed sizeof *a. */
static void setipaddress(struct in_addr *a, char *p)
{
	struct hostent *h = gethostbyname(p);

	/* gethostbyname() may return a non-IPv4 address (e.g. when the
	   resolver is configured for IPv6); copying h_length bytes blindly
	   would overflow the 4-byte struct in_addr. */
	if (h != NULL && h->h_addrtype == AF_INET &&
	    h->h_length == (int)sizeof *a) {
		memcpy(a, h->h_addr_list[0], sizeof *a);
		return;
	}
	if ((a->s_addr = inet_addr(p)) == INADDR_NONE) {
		error("unknown or invalid address [%s]\n", p);
	}
}

static void setaddress(struct in_addr *a, int *port, char *s,
		int dp, int *maxc, int *hard)
{
	struct hostent *h;
	char address[1024], pno[100];
	int n = sscanf(s, "%999[^:]:%99[^:]:%d:%d", address, pno, maxc, hard);

	if (n > 1) *port = getport(pno);
	else *port = dp;
	if (n < 3) *maxc = 0;
	if (n < 4) *hard = 0;

	if (debuglevel)
		debug("n = %d, address = %s, pno = %d, maxc1 = %d, hard = %d",
			n, address, *port, *maxc, *hard);

	if (!(h = gethostbyname(address))) {
		if ((a->s_addr = inet_addr(address)) == -1) {
			error("unknown or invalid address [%s]\n", address);
		}
	} else {
		memcpy(a, h->h_addr, h->h_length);
	}
}

/* Send one log record for connection i over socket fd to logserver with
   sendto(). Record format:

   + client_ip server_ip request

   where "request" is the first line of r (up to the first CR or LF),
   truncated so that the record, including the trailing newline, fits in
   a 1024-byte buffer.
*/
static void netlog(int fd, int i, unsigned char *r, int n)
{
	char rec[1024];
	int len, room, pos;

	if (debuglevel) debug("netlog(%d, %d, %p, %d)", fd, i, r, n);

	/* inet_ntoa() returns a static buffer, so format one address per call. */
	len = snprintf(rec, sizeof rec, "+ %s ",
		inet_ntoa(clients[conns[i].clt].addr));
	len += snprintf(rec + len, sizeof rec - len, "%s ",
		inet_ntoa(servers[conns[i].index].addr));

	/* Leave one byte for the trailing newline. */
	room = (int)sizeof rec - (len + 1);
	if (n > room) n = room;
	for (pos = 0; pos < n && r[pos] != '\r' && r[pos] != '\n'; pos++)
		rec[len++] = r[pos];
	rec[len++] = '\n';
	sendto(fd, rec, len, 0, (struct sockaddr *)&logserver, sizeof logserver);
}

/* Append one access-log line for connection i to fp:

    client_ip timestamp server_ip request

   timestamp is time() in seconds. request is at most KEEP_MAX bytes of b,
   cut at the first CR or LF. b comes from the client (see copy_up), so
   every non-printable byte is written as '.' to keep control characters
   such as terminal escapes or NUL out of the log.
*/
static void log(FILE *fp, int i, unsigned char *b, int n)
{
	int j;
	if (n > KEEP_MAX) n = KEEP_MAX;
	fprintf(fp, "%s ", inet_ntoa(clients[conns[i].clt].addr));
	fprintf(fp, "%ld ", (long)time(NULL));
	fprintf(fp, "%s ", inet_ntoa(servers[conns[i].index].addr));
	for (j = 0; j < n && b[j] != '\r' && b[j] != '\n'; j++) {
		putc(isprint(b[j]) ? b[j] : '.', fp);
	}
	putc('\n', fp);
}

/* Move one read's worth of data from the client (downfd) to the server
   (upfd) for connection i, logging it first if logging is enabled.
   Whatever the server socket does not accept immediately (or everything,
   with delayed_forward) is saved in the up buffer for flush_up().
   Returns 0 to keep the connection, -1 to close it (EOF or a fatal error). */
static int copy_up(int i)
{
	int rc, read_errno;
	int n;
	int from = conns[i].downfd;
	int to = conns[i].upfd;
	int serverindex = conns[i].index;

	unsigned char b[BUFFER_MAX];

	rc = read(from, b, BUFFER_MAX);
	read_errno = errno;	/* debug() below may clobber errno */

	if (debuglevel > 1) debug("copy_up(%d) %d bytes", i, rc);
	if (debuglevel > 2) dump(b, rc);

	if (rc == 0) return -1;	/* client closed */
	if (rc < 0) {
		/* A spurious wakeup on a non-blocking socket is not fatal;
		   same policy as the write error handling below. */
		if (nblock && read_errno == EAGAIN) return 0;
		return -1;
	}

	if (logfp) {
		log(logfp, i, b, rc);
		if (debuglevel > 2) log(stderr, i, b, rc);
	}
	if (logsock != -1) {
		netlog(logsock, i, b, rc);
	}

	if (delayed_forward) n = 0;
	else n = write(to, b, rc);

	if (n < 0) {
		if (!nblock || errno != EAGAIN) return -1;
		n = 0;
	}
	if (n != rc) {
		if (debuglevel > 1) {
			debug("copy_up saving %d bytes in up buffer",
				rc-n);
		}
		conns[i].upn = rc-n;
		conns[i].upbptr = conns[i].upb = pen_malloc(rc-n);
		memcpy(conns[i].upb, b+n, rc-n);
	}
	servers[serverindex].sx += rc;
	clients[conns[i].clt].crx += rc;
	return 0;
}

/* Move one read's worth of data from the server (upfd) to the client
   (downfd) for connection i. Whatever the client socket does not accept
   immediately (or everything, with delayed_forward) is saved in the down
   buffer for flush_down(). Only bytes actually written to the client are
   counted in its csx total. Returns 0 to keep the connection, -1 to close
   it (EOF or a fatal error). */
static int copy_down(int i)
{
	int rc, read_errno;
	int n;
	int from = conns[i].upfd;
	int to = conns[i].downfd;
	int serverindex = conns[i].index;

	unsigned char b[BUFFER_MAX];

	rc = read(from, b, BUFFER_MAX);
	read_errno = errno;	/* debug() below may clobber errno */

	if (debuglevel > 1) debug("copy_down(%d) %d bytes", i, rc);
	if (debuglevel > 2) dump(b, rc);

	if (rc == 0) return -1;	/* server closed */
	if (rc < 0) {
		/* A spurious wakeup on a non-blocking socket is not fatal;
		   same policy as the write error handling below. */
		if (nblock && read_errno == EAGAIN) return 0;
		return -1;
	}

	if (delayed_forward) n = 0;
	else n = write(to, b, rc);

	if (n < 0) {
		if (!nblock || errno != EAGAIN) return -1;
		n = 0;
	}
	if (n != rc) {
		if (debuglevel > 1) {
			debug("copy_down saving %d bytes in down buffer",
				rc-n);
		}
		conns[i].downn = rc-n;
		conns[i].downbptr = conns[i].downb = pen_malloc(rc-n);
		memcpy(conns[i].downb, b+n, rc-n);
	}
	servers[serverindex].rx += rc;
	clients[conns[i].clt].csx += n;
	return 0;
}

/* SIGALRM handler installed by try_server() around connect(). It
   deliberately does nothing: its job is only to catch the alarm so the
   signal does not take its default action. */
static void alarm_handler(int dummy)
{
	(void)dummy;
}

static void store_conn(int downfd, int upfd, int clt, int index)
{
	int i, fl;

	if (nblock) {
		if ((fl = fcntl(downfd, F_GETFL, 0)) == -1)
			error("Can't fcntl, errno = %d", errno);
		if (fcntl(downfd, F_SETFL, fl | O_NONBLOCK) == -1)
			error("Can't fcntl, errno = %d", errno);
		if ((fl = fcntl(upfd, F_GETFL, 0)) == -1)
			error("Can't fcntl, errno = %d", errno);
		if (fcntl(upfd, F_SETFL, fl | O_NONBLOCK) == -1)
			error("Can't fcntl, errno = %d", errno);
	}
	for (i = 0; i < connections_max; i++) {
		if (conns[i].upfd == -1) break;
	}
	if (i < connections_max) {
		conns[i].upfd = upfd;
		conns[i].downfd = downfd;
		conns[i].clt = clt;
		conns[i].index = index;
		current = index;
	} else if (debuglevel) {
		debug("Connection table full (%d slots), can't store connection.\n"
		      "Try restarting with -x %d",
		      connections_max, 2*connections_max);
		close(downfd);
		close(upfd);
	}
}

/* Tear down connection slot i: decrement its server's connection count,
   close both sockets, mark the slot free (fds = -1) and release any
   pending-write buffers. */
static void close_conn(int i)
{
	int index = conns[i].index;

	servers[index].c -= 1;
	/* 0 is a valid descriptor (e.g. if stdin was closed), so test >= 0. */
	if (conns[i].upfd >= 0) close(conns[i].upfd);
	if (conns[i].downfd >= 0) close(conns[i].downfd);
	conns[i].upfd = conns[i].downfd = -1;
	if (conns[i].downn) {
		free(conns[i].downb);
		conns[i].downb = conns[i].downbptr = NULL;
		conns[i].downn = 0;
	}
	if (conns[i].upn) {
		free(conns[i].upb);
		conns[i].upb = conns[i].upbptr = NULL;
		conns[i].upn = 0;
	}
}


/* Print command-line help, including the compiled-in defaults, to stdout
   and exit with status 0. */
static void usage(void)
{
	printf("usage:\n"
	       "  pen [-C addr:port] [-b sec] [-c N] [-e host[:port]] \\\n"
	       "          [-j dir] [-l file] [-p file] [-t sec] [-u user] \\\n"
	       "          [-x N] [-w file] [-Padfhnrs] \\\n"
	       "          [host:]port h1[:p1[:maxc1[:hard1]]] [h2[:p2[:maxc2[:hard2]]]] ...\n"
	       "\n"
	       "  -C port   control port\n"
	       "  -P        use poll() rather than select()\n"
	       "  -a        debugging dumps in ascii format\n"
	       "  -b sec    blacklist time in seconds [%d]\n"
	       "  -c N      max number of clients [%d]\n"
	       "  -d        debugging on (repeat -d for more)\n"
	       "  -e host:port emergency server of last resort\n"
	       "  -f        stay in foreground\n"
	       "  -h        use hash for initial server selection\n"
	       "  -j dir    run in chroot\n"
	       "  -l file   logging on\n"
	       "  -n        do not make sockets nonblocking\n"
	       "  -r        bypass client tracking in server selection\n"
	       "  -s        stubborn selection, i.e. don't fail over\n"
	       "  -t sec    connect timeout in seconds [%d]\n"
	       "  -u user   run as alternative user\n"
	       "  -p file   write pid to file\n"
	       "  -x N      max number of simultaneous connections [%d]\n"
	       "  -w file   save statistics in HTML format in a file\n"
	       "\n"
	       "example:\n"
	       "  pen smtp mailhost1:smtp mailhost2:25 mailhost3\n"
	       "\n",
	       BLACKLIST_TIME, CLIENTS_MAX, TIMEOUT, CONNECTIONS_MAX);

	exit(0);
}

/* Detach from the terminal and keep running in the background, using
   daemon() when HAVE_DAEMON is defined and fork() + setsid() otherwise.
   Exits via error() if detaching fails. */
static void background(void)
{
#ifdef HAVE_DAEMON
	/* Report failure like the fork() path below does, rather than
	   silently carrying on in the foreground. */
	if (daemon(0, 0) == -1)
		error("Can't daemonize, errno = %d", errno);
#else
	int childpid;
	if ((childpid = fork()) < 0) {
		error("Can't fork");
	} else {
		if (childpid > 0) {
			exit(0);	/* parent */
		}
	}
	setsid();
	signal(SIGCHLD, SIG_IGN);
#endif
}

static void init(int argc, char **argv)
{
	int i;
	int server;

	conns = pen_calloc(connections_max, sizeof *conns);
	clients = pen_calloc(clients_max, sizeof *clients);
	servers = pen_calloc(argc+1, sizeof *servers);

	nservers = 0;
	current = 0;

	server = 0;

	for (i = 1; i < argc; i++) {
		servers[server].status = 0;
		servers[server].c = 0;	/* connections... */
		setaddress(&servers[server].addr, &servers[server].port,
			   argv[i], port,
			   &servers[server].maxc, &servers[server].hard);
		servers[server].sx = 0;
		servers[server].rx = 0;

		nservers++;
		server++;
	}
	if (e_server) {
		emerg_server = server;
		servers[server].status = 0;
		servers[server].c = 0;	/* connections... */
		setaddress(&servers[server].addr, &servers[server].port,
			   e_server, port,
			   &servers[server].maxc, &servers[server].hard);
		servers[server].sx = 0;
		servers[server].rx = 0;
		server++;
	}

	for (i = 0; i < clients_max; i++) {
		clients[i].last = 0;
		clients[i].addr.s_addr = 0;
		clients[i].cno = 0;
		clients[i].connects = 0;
		clients[i].csx = 0;
		clients[i].crx = 0;
	}
	for (i = 0; i < connections_max; i++) {
		conns[i].upfd = -1;
		conns[i].downfd = -1;
		conns[i].upn = 0;
		conns[i].downn = 0;
	}

	if (debuglevel) {
		debug("servers:");
		for (i = 0; i < nservers; i++) {
			debug("%2d %s:%d:%d:%d", i,
				inet_ntoa(servers[i].addr), servers[i].port,
				servers[i].maxc, servers[i].hard);
		}
	}
}

/* return upstream file descriptor */
/* sticky = 1 if this client has used the server before */
int try_server(int index, int sticky)
{
	struct sockaddr_in serv_addr;
        int upfd;
	int n;
        int now = (int)time(NULL);
	if (debuglevel) debug("Trying server %d at time %d", index, now);
        if (now-servers[index].status < blacklist_time) {
		if (debuglevel) debug("Server %d is blacklisted", index);
		return -1;
	}
        if (servers[index].maxc != 0 &&
            (servers[index].c >= servers[index].maxc) &&
	    (sticky == 0 || servers[index].c >= servers[index].hard)) {
		if (debuglevel)
			debug("Server %d is overloaded: sticky=%d, maxc=%d, hard=%d",
				index, sticky,
				servers[index].maxc, servers[index].hard);
		return -1;
	}
        upfd = socket(AF_INET, SOCK_STREAM, 0);
        if (upfd < 0) error("Error opening socket");
        memset(&serv_addr, 0, sizeof serv_addr);
        serv_addr.sin_family = AF_INET;
        serv_addr.sin_addr.s_addr = servers[index].addr.s_addr;
        serv_addr.sin_port = htons(servers[index].port);
        signal(SIGALRM, alarm_handler);
        alarm(timeout);
	n = connect(upfd, (struct sockaddr *)&serv_addr, sizeof serv_addr);
	alarm(0);	/* cancel scheduled timeout, if there is one */
        if (n == -1) {
		if (servers[index].status)
			debug("Server %d failed: %s", index, strerror(errno));
                servers[index].status = (int)time(NULL);
		close(upfd);
                return -1;
        }
	if (servers[index].status) {
		servers[index].status=0;
		debug("Server %d ok", index);
	}
	if (debuglevel) debug("Successful connect to server %d", index);
        return upfd;
}

/* Handle one command from a control-port connection, then close it.
   Reads a single request of at most sizeof b - 1 bytes, cuts it at the
   first CR or LF, and dispatches on the first space-separated word.
   Query commands write their answer back on downfd. cli_addr is unused. */
static void do_ctrl(int downfd, struct sockaddr_in *cli_addr)
{
	char b[4096], *p, *q;
	int n, max_b = sizeof b;
	FILE *fp;

	n = read(downfd, b, max_b-1);
	if (n == -1) goto Done;

	b[n] = '\0';
	p = strchr(b, '\r');
	if (p) *p = '\0';
	p = strchr(b, '\n');
	if (p) *p = '\0';
	p = strtok(b, " ");
	if (p == NULL) goto Done;
	/* Replies use snprintf: some strings (logfile, listenport, ctrlport)
	   come from startup configuration and are not bounded by b. */
	if (!strcmp(p, "ascii")) {
		asciidump = 1;
	} else if (!strcmp(p, "blacklist")) {
		p = strtok(NULL, " ");
		if (p) blacklist_time = atoi(p);
		snprintf(b, sizeof b, "%d\n", blacklist_time);
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "block")) {
		nblock = 0;
	} else if (!strcmp(p, "clients_max")) {
		snprintf(b, sizeof b, "%d\n", clients_max);
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "conn_max")) {
		snprintf(b, sizeof b, "%d\n", connections_max);
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "control")) {
		snprintf(b, sizeof b, "%s\n", ctrlport);
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "debug")) {
		p = strtok(NULL, " ");
		if (p) debuglevel = atoi(p);
		snprintf(b, sizeof b, "%d\n", debuglevel);
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "delayed_forward")) {
		delayed_forward = 1;
	} else if (!strcmp(p, "hash")) {
		hash = 1;
	} else if (!strcmp(p, "listen")) {
		snprintf(b, sizeof b, "%s\n", listenport);
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "log")) {
		p = strtok(NULL, " ");
		if (p) {
			free(logfile);
			logfile = pen_strdup(p);
			if (logfp) fclose(logfp);
			logfp = fopen(logfile, "w");
		}
		if (logfile) {
			snprintf(b, sizeof b, "%s\n", logfile);
			write(downfd, b, strlen(b));
		}
	} else if (!strcmp(p, "mode")) {
		snprintf(b, sizeof b, "%sblock %sdelayed_forward %shash %sroundrobin %sstubborn\n",
			nblock?"no ":"",
			delayed_forward?"":"no ",
			hash?"":"no ",
			roundrobin?"":"no ",
			stubborn?"":"no ");
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "no")) {
		p = strtok(NULL, " ");
		if (p == NULL) goto Done;
		if (!strcmp(p, "ascii")) {
			asciidump = 0;
		} else if (!strcmp(p, "block")) {
			nblock = 1;
		} else if (!strcmp(p, "delayed_forward")) {
			delayed_forward = 0;
		} else if (!strcmp(p, "hash")) {
			hash = 0;
		} else if (!strcmp(p, "log")) {
			logfile = NULL;
			/* Logging may not be enabled; fclose(NULL) is undefined. */
			if (logfp) fclose(logfp);
			logfp = NULL;
		} else if (!strcmp(p, "roundrobin")) {
			roundrobin = 0;
		} else if (!strcmp(p, "stubborn")) {
			stubborn = 0;
		} else if (!strcmp(p, "web_stats")) {
			webfile = NULL;
		}
	} else if (!strcmp(p, "pid")) {
		snprintf(b, sizeof b, "%ld\n", (long)getpid());
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "recent")) {
		/* Clients seen in the last N seconds (default 300). */
		time_t when = time(NULL);
		p = strtok(NULL, " ");
		if (p) when -= atoi(p);
		else when -= 300;
		for (n = 0; n < clients_max; n++) {
			if (clients[n].last < when) continue;
			snprintf(b, sizeof b, "%s connects %ld sx %lld rx %lld\n",
				inet_ntoa(clients[n].addr),
				clients[n].connects,
				clients[n].csx, clients[n].crx);
			write(downfd, b, strlen(b));
		}
	} else if (!strcmp(p, "roundrobin")) {
		roundrobin = 1;
	} else if (!strcmp(p, "server")) {
		/* server N [address A] [port P] [max M] [hard H] [blacklist S] */
		p = strtok(NULL, " ");
		if (p == NULL) goto Done;
		n = atoi(p);
		if (n < 0 || n >= nservers) goto Done;
		while ((p = strtok(NULL, " ")) && (q = strtok(NULL, " "))) {
			if (!strcmp(p, "address")) {
				setipaddress(&servers[n].addr, q);
			} else if (!strcmp(p, "port")) {
				servers[n].port = atoi(q);
			} else if (!strcmp(p, "max")) {
				servers[n].maxc = atoi(q);
			} else if (!strcmp(p, "hard")) {
				servers[n].hard = atoi(q);
			} else if (!strcmp(p, "blacklist")) {
				servers[n].status = time(NULL)+atoi(q)-blacklist_time;
			}
		}
	} else if (!strcmp(p, "servers")) {
		for (n = 0; n < nservers; n++) {
			snprintf(b, sizeof b, "%d addr %s port %d conn %d max %d hard %d sx %llu rx %llu\n",
				n, inet_ntoa(servers[n].addr), servers[n].port,
				servers[n].c, servers[n].maxc, servers[n].hard,
				servers[n].sx, servers[n].rx);
			write(downfd, b, strlen(b));
		}
	} else if (!strcmp(p, "status")) {
		/* Render webstats() into a temporary file and echo it back.
		   NOTE(review): a fixed, predictable /tmp path is open to
		   symlink attacks; consider mkstemp(). */
		p = webfile;
		webfile = "/tmp/webfile.html";
		webstats();
		fp = fopen(webfile, "r");
		webfile = p;
		if (fp == NULL) goto Done;
		while (fgets(b, sizeof b, fp)) {
			write(downfd, b, strlen(b));
		}
		fclose(fp);
	} else if (!strcmp(p, "stubborn")) {
		stubborn = 1;
	} else if (!strcmp(p, "timeout")) {
		p = strtok(NULL, " ");
		if (p) timeout = atoi(p);
		snprintf(b, sizeof b, "%d\n", timeout);
		write(downfd, b, strlen(b));
	} else if (!strcmp(p, "web_stats")) {
		p = strtok(NULL, " ");
		if (p) {
			free(webfile);
			webfile = pen_strdup(p);
		}
		if (webfile) {
			snprintf(b, sizeof b, "%s\n", webfile);
			write(downfd, b, strlen(b));
		}
	} else {
		goto Done;
	}

Done:
	close(downfd);
}

/* Pick a server for a newly accepted client socket downfd and start
   proxying. Selection order:
   1. the client's previous server (unless roundrobin, or it was the
      emergency server);
   2. the client-address hash (if hash is set and roundrobin is not);
   3. round-robin starting after `current` (unless stubborn);
   4. the emergency server, if configured.
   If nothing accepts, the client socket is closed. */
static void add_client(int downfd, struct sockaddr_in *cli_addr)
{
	int upfd = -1, clino = -1, index = -1, n;

	if (roundrobin) {
		if (debuglevel) debug("Bypassing client tracking");
	} else {
		clino = lookup_client(cli_addr->sin_addr);
		if (debuglevel) debug("lookup_client returns %d", clino);
		if (clino != -1) {
			index = clients[clino].cno;
			if (index != emerg_server) {
				upfd = try_server(index, 1);
				if (upfd != -1) goto Success;
			}
		}
		if (hash) {
			index = cli_addr->sin_addr.s_addr % nservers;
			upfd = try_server(index, 0);
			if (upfd != -1) goto Success;
		}
	}
	if (!stubborn) {
		index = current;
		do {
			index = (index + 1) % nservers;
			if ((upfd = try_server(index, 0)) != -1) goto Success;
		} while (index != current);
	}
	/* if we get here, we're dead */
	if (emerg_server != -1) {
		debug("Using emergency server");
		/* Record the connection against the emergency server itself.
		   Otherwise index is stale (or still -1 with roundrobin plus
		   stubborn), and store_client() would charge the wrong server or
		   index servers[-1]. */
		index = emerg_server;
		if ((upfd = try_server(index, 0)) != -1) goto Success;
	}
	debug("Couldn't find a server for client");
	if (downfd != -1) close(downfd);
	if (upfd != -1) close(upfd);
	return;
Success:
	n = store_client(clino, cli_addr->sin_addr, index);
	store_conn(downfd, upfd, n, index);
	return;
}

static int open_listener(char *a)
{
	int listenfd;
	struct sockaddr_in serv_addr;
	char b[1024], *p;
	int one = 1;

	memset(&serv_addr, 0, sizeof serv_addr);
	serv_addr.sin_family = AF_INET;
	p = strchr(a, ':');
	if (p) {
		strncpy(b, a, sizeof b);
		b[sizeof b-1] = '\0';
		p = strchr(b, ':');
		*p = '\0';
		port = getport(p+1);
		setipaddress(&serv_addr.sin_addr, b);
		snprintf(b, (sizeof(b) - 1), "%s", inet_ntoa(serv_addr.sin_addr));
	} else {
		port = getport(a);
		serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
		sprintf(b, "0.0.0.0");
	}
	serv_addr.sin_port = htons(port);

	if ((listenfd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		error("can't open stream socket");
	}

	if (debuglevel) debug("local address=[%s:%d]", b, port);

	setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, (char *)&one, sizeof one);
	if (bind(listenfd, (struct sockaddr *) &serv_addr,
		 sizeof serv_addr) < 0) {
		error("can't bind local address");
	}

	listen(listenfd, 50);
	return listenfd;
}

static int flush_down(int i)
{
	int n = write(conns[i].downfd, conns[i].downbptr, conns[i].downn);

	if (debuglevel > 1) debug("flush_down(%d) %d bytes", i, n);
	if (n > 0) {
		conns[i].downn -= n;
		if (conns[i].downn == 0) 
			free(conns[i].downb);
		else
			conns[i].downbptr += n;
		clients[conns[i].clt].csx += n;
	}
	return n;
}

static int flush_up(int i)
{
	int n = write(conns[i].upfd, conns[i].upbptr, conns[i].upn);

	if (debuglevel > 1) debug("flush_up(%d) %d bytes", i, n);
	if (n > 0) {
		conns[i].upn -= n;
		if (conns[i].upn == 0)
			free(conns[i].upb);
		else
			conns[i].upbptr += n;
	}
	return n;
}

/* select()-based event loop. Accepts client connections on listenfd and
   control connections on ctrlfd (-1 if there is no control port), and
   moves data for every active connection until loopflag is cleared by
   SIGTERM. Stats and log-reopen requests raised by the signal handlers
   are serviced at the top of each iteration. */
static void mainloop_select(int listenfd, int ctrlfd)
{
	int downfd;
	socklen_t clilen;	/* accept() takes a socklen_t *, not an int * */
	struct sockaddr_in cli_addr;
	fd_set w_read, w_write, w_error;
	int i, w_max;
	signal(SIGUSR1, stats);
	signal(SIGHUP, restart_log);
	signal(SIGTERM, quit);
	signal(SIGPIPE, SIG_IGN);

	loopflag = 1;

	if (debuglevel) debug("mainloop_select(%d, %d)", listenfd, ctrlfd);
	while (loopflag) {
		int n;

		if (do_stats) {
			if (webfile) webstats();
			else textstats();
			do_stats=0;
		}
		if (do_restart_log) {
			if (logfp) {
				fclose(logfp);
				logfp = fopen(logfile, "a");
				if (!logfp) 
					error("Can't open logfile %s", logfile);
			}
			do_restart_log=0;
		}
		FD_ZERO(&w_read);
		FD_ZERO(&w_write);
		FD_ZERO(&w_error);
		FD_SET(listenfd, &w_read);	/* new connections */
		w_max = listenfd+1;
		if (ctrlfd != -1) {
			FD_SET(ctrlfd, &w_read);
			if (ctrlfd > listenfd) w_max = ctrlfd+1;
		}

		/* add sockets from open connections */
		for (i = 0; i < connections_max; i++) {
			if (conns[i].downfd == -1) continue;
			/* FD_SET on a descriptor >= FD_SETSIZE writes past the end
			   of the fd_set, so such connections cannot be served here. */
			if (conns[i].downfd >= FD_SETSIZE ||
			    conns[i].upfd >= FD_SETSIZE) {
				debug("Connection %d: descriptor >= FD_SETSIZE (%d), closing; try -P",
					i, (int)FD_SETSIZE);
				close_conn(i);
				continue;
			}
			/* Read from a side only when nothing is queued for the
			   other side; otherwise wait for the other side to drain. */
			if (conns[i].upn == 0) {
				FD_SET(conns[i].downfd, &w_read);
				if (conns[i].downfd+1 > w_max) {
					w_max = conns[i].downfd+1;
				}
			} else {
				FD_SET(conns[i].upfd, &w_write);
				if (conns[i].upfd+1 > w_max) {
					w_max = conns[i].upfd+1;
				}
			}
			if (conns[i].downn == 0) {
				FD_SET(conns[i].upfd, &w_read);
				if (conns[i].upfd+1 > w_max) {
					w_max = conns[i].upfd+1;
				}
			} else {
				FD_SET(conns[i].downfd, &w_write);
				if (conns[i].downfd+1 > w_max) {
					w_max = conns[i].downfd+1;
				}
			}
		}

		/* Wait for a connection from a client process. */
		n = select(w_max, &w_read, &w_write, /*&w_error*/0, 0);
		if (n < 0 && errno != EINTR) {
			perror("select");
			error("Error on select");
		}
		if (n <= 0) continue;

		if (FD_ISSET(listenfd, &w_read)) {
			clilen = sizeof cli_addr;
			downfd = accept(listenfd,
				(struct sockaddr *) &cli_addr, &clilen);
			if (downfd < 0) {
				if (debuglevel) perror("accept");
				continue;
			}
			if (clilen == 0) {
				if (debuglevel) perror("clilen");
				continue;
			}
			add_client(downfd, &cli_addr);
		}

		/* check control port */
		if (ctrlfd != -1 && FD_ISSET(ctrlfd, &w_read)) {
			clilen = sizeof cli_addr;
			downfd = accept(ctrlfd,
				(struct sockaddr *) &cli_addr, &clilen);
			if (downfd < 0) {
				if (debuglevel) perror("accept");
				continue;
			}
			if (clilen == 0) {
				if (debuglevel) perror("clilen");
				continue;
			}
			do_ctrl(downfd, &cli_addr);
		}

		/* check sockets from open connections */
		for (i = 0; i < connections_max; i++) {
			if (conns[i].downfd == -1) continue;
			if (FD_ISSET(conns[i].downfd, &w_read)) {
				if (copy_up(i) < 0) {
					close_conn(i);
					continue;
				}
			}
			if (FD_ISSET(conns[i].upfd, &w_read)) {
				if (copy_down(i) < 0) {
					close_conn(i);
					continue;
				}
			}
			if (FD_ISSET(conns[i].downfd, &w_write)) {
				if (flush_down(i) < 0) {
					close_conn(i);
					continue;
				}
			}
			if (FD_ISSET(conns[i].upfd, &w_write)) {
				if (flush_up(i) < 0) {
					close_conn(i);
					continue;
				}
			}
		}
	}
}

#ifdef HAVE_POLL
/* Debug helper: log each of the first nfds pollfd entries as
   "index: <fd,events,revents>". */
static void dump_pollfd(struct pollfd *ufds, int nfds)
{
	int k = 0;

	while (k < nfds) {
		const struct pollfd *pfd = &ufds[k];

		debug("%d: <%d,%d,%d>", k,
			pfd->fd, (int)pfd->events, (int)pfd->revents);
		k++;
	}
}

/*
 * Accept one pending connection on the listening socket lfd and store the
 * peer address in *addr.
 *
 * Returns the new descriptor, or -1 when accept() failed or returned an
 * empty address. In the empty-address case the descriptor is closed here
 * instead of being leaked.
 */
static int poll_accept(int lfd, struct sockaddr_in *addr)
{
	socklen_t len = sizeof *addr;
	int fd = accept(lfd, (struct sockaddr *) addr, &len);

	if (fd < 0) {
		if (debuglevel) perror("accept");
		return -1;
	}
	if (len == 0) {
		if (debuglevel) perror("clilen");
		close(fd);
		return -1;
	}
	return fd;
}

/*
 * Consume the next pollfd entry if it belongs to fd.
 *
 * If ufds[*j] (with *j < nfds) is the entry for fd, advance *j and return
 * that entry's revents. Otherwise return 0 and leave *j unchanged.
 *
 * poll() always reports POLLERR/POLLHUP/POLLNVAL, even when they were not
 * requested. Those conditions are widened to the events that were requested.
 * The caller's read/write handler then runs and sees the failure, instead of
 * poll() waking up forever on an entry nobody services.
 */
static short poll_take_revents(const struct pollfd *ufds, int nfds,
			       int *j, int fd)
{
	short rev;

	if (*j >= nfds || ufds[*j].fd != fd) return 0;
	rev = ufds[*j].revents;
	if (rev & (POLLERR | POLLHUP | POLLNVAL)) rev |= ufds[*j].events;
	(*j)++;
	return rev;
}

/*
 * Main event loop built on poll(2), used instead of the select() loop
 * when use_poll is set.
 *
 * listenfd: listening socket for client connections; accepted sockets
 *           go to add_client().
 * ctrlfd:   listening socket for the control port, or -1 if there is none;
 *           accepted sockets go to do_ctrl().
 *
 * Runs while loopflag is nonzero. Each iteration:
 *   1. handles the do_stats / do_restart_log flags;
 *   2. rebuilds the pollfd array from conns[];
 *   3. blocks in poll();
 *   4. dispatches readiness to copy_up / copy_down / flush_down / flush_up,
 *      closing a connection as soon as one of them returns < 0.
 */
static void mainloop_poll(int listenfd, int ctrlfd)
{
	int downfd;
	struct sockaddr_in cli_addr;
	struct pollfd *ufds;
	int i, j, nfds;
	short downevents, upevents;

	signal(SIGUSR1, stats);
	signal(SIGHUP, restart_log);
	signal(SIGTERM, quit);
	signal(SIGPIPE, SIG_IGN);

	loopflag = 1;

	/* Worst case: listener + control listener + two sockets per connection. */
	ufds = pen_malloc((connections_max*2+2)*sizeof *ufds);

	if (debuglevel) {
		debug("POLLIN = %d, POLLOUT = %d", (int)POLLIN, (int)POLLOUT);
		debug("mainloop_poll(%d, %d)", listenfd, ctrlfd);
	}
	while (loopflag) {
		int n;

		/* Deferred work; these flags are presumably set by the
		 * SIGUSR1/SIGHUP handlers installed above. */
		if (do_stats) {
			if (webfile) webstats();
			else textstats();
			do_stats=0;
		}
		if (do_restart_log) {
			if (logfp) {
				fclose(logfp);
				logfp = fopen(logfile, "a");
				if (!logfp)
					error("Can't open logfile %s", logfile);
			}
			do_restart_log=0;
		}

		nfds = 0;
		ufds[nfds].fd = listenfd;
		ufds[nfds++].events = POLLIN;	/* new connections */
		if (ctrlfd != -1) {
			ufds[nfds].fd = ctrlfd;
			ufds[nfds++].events = POLLIN;
		}

		/*
		 * Add sockets from open connections. A side is read only when
		 * the buffer it fills is empty (upn/downn == 0). Otherwise we
		 * wait for the opposite side to become writable so the buffer
		 * can be flushed. Entries are added down-side first.
		 */
		for (i = 0; i < connections_max; i++) {
			if (conns[i].downfd == -1) continue;
			upevents = downevents = 0;

			if (conns[i].upn == 0) downevents |= POLLIN;
			else upevents |= POLLOUT;

			if (conns[i].downn == 0) upevents |= POLLIN;
			else downevents |= POLLOUT;

			if (downevents) {
				ufds[nfds].fd = conns[i].downfd;
				ufds[nfds++].events = downevents;
			}
			if (upevents) {
				ufds[nfds].fd = conns[i].upfd;
				ufds[nfds++].events = upevents;
			}
		}

		if (debuglevel) dump_pollfd(ufds, nfds);

		/* Wait for activity on any socket. */
		n = poll(ufds, nfds, -1);
		/* Test errno before any debug output can overwrite it. */
		if (n < 0 && errno != EINTR) {
			perror("poll");
			error("Error on poll");
		}
		if (debuglevel) debug("n = %d", n);
		if (n <= 0) continue;
		if (debuglevel) dump_pollfd(ufds, nfds);

		/* Entries are consumed in exactly the order they were added. */
		j = 0;

		/*
		 * New client connection. A failed accept only skips the new
		 * client; it must not abandon servicing existing connections.
		 */
		if (ufds[j].revents & POLLIN) {
			downfd = poll_accept(listenfd, &cli_addr);
			if (downfd != -1) add_client(downfd, &cli_addr);
		}
		j++;

		/* check control port */
		if (ctrlfd != -1) {
			if (ufds[j].revents & POLLIN) {
				downfd = poll_accept(ctrlfd, &cli_addr);
				if (downfd != -1) do_ctrl(downfd, &cli_addr);
			}
			j++;	/* the control listener owns exactly one entry */
		}

		/*
		 * Check sockets from open connections. A connection created by
		 * add_client() in this iteration has no entry; the fd comparison
		 * in poll_take_revents() skips it.
		 */
		for (i = 0; i < connections_max; i++) {
			if (conns[i].downfd == -1) continue;

			downevents = poll_take_revents(ufds, nfds, &j,
						       conns[i].downfd);
			upevents = poll_take_revents(ufds, nfds, &j,
						     conns[i].upfd);

			if (debuglevel && (downevents || upevents))
				debug("down[%d] = %d, up[%d] = %d",
				      i, downevents, i, upevents);

			if (downevents & POLLIN) {
				if (copy_up(i) < 0) {
					close_conn(i);
					continue;
				}
			}
			if (upevents & POLLIN) {
				if (copy_down(i) < 0) {
					close_conn(i);
					continue;
				}
			}
			if (downevents & POLLOUT) {
				if (flush_down(i) < 0) {
					close_conn(i);
					continue;
				}
			}
			if (upevents & POLLOUT) {
				if (flush_up(i) < 0) {
					close_conn(i);
					continue;
				}
			}
		}
	}
}
#endif

/*
 * Sanity ceiling for numeric options. It keeps expressions such as
 * connections_max*2+10 within the range of a 32-bit int.
 */
enum { OPTARG_INT_CEIL = 1000000000 };

/*
 * Parse a numeric option argument.
 *
 * The whole string must be a base-10 integer in [lo, hi]. An empty string,
 * trailing junk or an out-of-range value is reported through usage().
 *
 * Returns the parsed value, or lo if usage() returns.
 */
static int optarg_int(const char *arg, long lo, long hi)
{
	char *end;
	long v;

	errno = 0;
	v = strtol(arg, &end, 10);
	if (end == arg || *end != '\0' || errno == ERANGE || v < lo || v > hi) {
		usage();
		return (int)lo;
	}
	return (int)v;
}

/*
 * Parse command-line switches into the global configuration variables.
 *
 * Unknown switches, and malformed or out-of-range numeric arguments, are
 * reported through usage().
 *
 * Returns optind, the index of the first non-option argument.
 */
static int options(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "C:b:c:e:j:l:p:t:u:w:x:DPadfhnrs")) != -1) {
		switch (c) {
		case 'C':
			ctrlport = optarg;
			break;
		case 'D':
			delayed_forward = 1;
			break;
		case 'P':
			use_poll = 1;
			break;
		case 'a':
			asciidump = 1;
			break;
		case 'b':
			blacklist_time = optarg_int(optarg, 0, OPTARG_INT_CEIL);
			break;
		case 'c':
			clients_max = optarg_int(optarg, 1, OPTARG_INT_CEIL);
			break;
		case 'd':
			debuglevel++;
			break;
		case 'e':
			e_server = optarg;
			break;
		case 'f':
			foreground = 1;
			break;
		case 'h':
			hash = 1;
			break;
		case 'j':
			jail = optarg;
			break;
		case 'l':
			logfile = pen_strdup(optarg);
			break;
		case 'n':
			nblock = 0;
			break;
		case 'p':
			pidfile = optarg;
			break;
		case 'r':
			roundrobin = 1;
			break;
		case 's':
			stubborn = 1;
			break;
		case 't':
			/* a timeout below one second is rejected */
			timeout = optarg_int(optarg, 1, OPTARG_INT_CEIL);
			break;
		case 'u':
			user = optarg;
			break;
		case 'x':
			/* must be >= 1: it sizes the connection and pollfd tables */
			connections_max = optarg_int(optarg, 1, OPTARG_INT_CEIL);
			break;
		case 'w':
			webfile = pen_strdup(optarg);
			break;
		case '?':
		default:
			usage();
		}
	}

	return optind;
}

int main(int argc, char **argv)
{
	int i, listenfd, ctrlfd;
	struct passwd *pwd = NULL;
	struct rlimit r;
	int n = options(argc, argv);
	argc -= n;
	argv += n;

	if (argc < 1) {
		usage();
	}

	if ((connections_max*2+10) > FD_SETSIZE && !use_poll) 
		error("Number of simultaneous connections to large.\n"
		      "Maximum is %d, or re-build pen with larger FD_SETSIZE",
		      (FD_SETSIZE-10)/2);
	
	getrlimit(RLIMIT_CORE, &r);
	r.rlim_cur = r.rlim_max;
	setrlimit(RLIMIT_CORE, &r);

	signal(SIGCHLD, SIG_IGN);

	if (!foreground) {
		background();
	}

	/* we must open listeners before dropping privileges */
	if (ctrlport) ctrlfd = open_listener(ctrlport);
	else ctrlfd = -1;

	listenport = argv[0];
	listenfd = open_listener(listenport);
	init(argc, argv);

	/* we must look up user id before chrooting */
	if (user) {
		if (debuglevel) debug("Run as user %s", user);
		pwd = getpwnam(user);
		if (pwd == NULL) error("Can't getpwnam(%s)", user);
	}

	/* we must chroot before dropping privileges */
	if (jail) {
		if (debuglevel) debug("Run in %s", jail);
		if (chroot(jail) == -1) error("Can't chroot(%s)", jail);
	}

	/* ready to defang ourselves */
	if (pwd) {
		if (setuid(pwd->pw_uid) == -1)
			error("Can't setuid(%d)", (int)pwd->pw_uid);
	}

	if (logfile) {
		char *p = strchr(logfile, ':');
		if (p && logfile[0] != '/') {	/* log to net */
			struct hostent *hp;
			if (debuglevel) debug("net log to %s", logfile);
			*p++ = '\0';
			logsock = socket(PF_INET, SOCK_DGRAM, 0);
			if (logsock < 0) error("Can't create log socket");
			logserver.sin_family = AF_INET;
			hp = gethostbyname(logfile);
			if (hp == NULL) error("Bogus host %s", logfile);
			memcpy(&logserver.sin_addr.s_addr,
				hp->h_addr, hp->h_length);
			logserver.sin_port = htons(atoi(p));
		} else {	/* log to file */
			if (debuglevel) debug("file log to %s", logfile);
			logfp = fopen(logfile, "a");
			if (!logfp) error("Can't open logfile %s", logfile);
		}
	}
	if (pidfile) {
		pidfp = fopen(pidfile, "w");
		if (!pidfp) {
			error("Can't create pidfile %s", pidfile);
			exit(1);
		}
		fprintf(pidfp, "%d", (int)getpid());
		fclose(pidfp);
	}

#ifdef HAVE_POLL
	if (use_poll) mainloop_poll(listenfd, ctrlfd);
#else
	if (use_poll) error("You don't have poll()");
#endif
	if (!use_poll) mainloop_select(listenfd, ctrlfd);

	if (debuglevel) debug("Exiting, cleaning up...");
	if (logfp) fclose(logfp);
	for (i = 0; i < connections_max; i++) {
		close_conn(i);
	}
	close(listenfd);
	if (pidfile) {
		unlink(pidfile);
	}
	return 0;
}
