/* $Id: ipfreely.c,v 1.2 2005/03/27 15:20:39 nialloh Exp $ */

/*
 * Copyright (c) 2005 Niall O'Higgins <niallo@netsoc.ucd.ie>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF MIND, USE, DATA OR PROFITS, WHETHER
 * IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
 * OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/wait.h>

#include <err.h>
#include <errno.h>
#include <grp.h>
#include <poll.h>
#include <pwd.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>

#include <netdb.h>

#define BUFSIZE         65536   /* relay buffer size in transfer(); main() also passes it as the read length */
#define MAX_BACKLOG     5       /* backlog argument passed to listen() */

/* Prototypes for the functions defined in this file. */
void    chroot_jail(void);
int     create_server_socket(char *, char *);
void    daemonise(void);
void    debug(const char *format, ...);
void    drop_privs(void);
void    sigchild(int);
void    sigterm(int);
ssize_t transfer(int, int, size_t);
void    usage(void);

extern char *optarg;
extern int optind;

/*
 * Settings filled in by main() from the command line.
 * NOTE(review): MAXHOSTNAMELEN and MAXPATHLEN are not defined by any header
 * visibly included above (traditionally they come from <sys/param.h>);
 * presumably they arrive transitively -- confirm.
 */
int   debug_flag = 0;           /* -d: enables debug() output and keeps daemonise() from detaching */
char  username[10];             /* -u: copied with strlcpy(), so names longer than 9 chars are truncated */
char  local_port[6];            /* -b port: room for 5 characters plus NUL (e.g. "65535") */
char  remote_port[6];           /* -r port: room for 5 characters plus NUL */
char  local_host[MAXHOSTNAMELEN];       /* -b host: address handed to create_server_socket() */
char  remote_host[MAXHOSTNAMELEN];      /* -r host: address every child connects to */
char  chroot_dir[MAXPATHLEN];   /* -c: directory passed to chroot() by chroot_jail() */
uid_t _uid;                     /* -u: target uid for drop_privs(); 0 means unset and main() rejects it.
                                 * NOTE(review): leading-underscore file-scope names are reserved to the implementation. */

/*
 * Confine the process to chroot_dir (set with -c) and make the new root
 * the working directory.  Either failure terminates the program with
 * exit status 1.
 */
void
chroot_jail(void)
{
        int failed;

        failed = chroot(chroot_dir);
        if (failed)
                errx(1, "could not chroot(): %s", strerror(errno));

        failed = chdir("/");
        if (failed)
                errx(1, "could not chdir(): %s", strerror(errno));
}

/*
 * Create a TCP socket bound to host:port and put it in the listening
 * state.  A NULL or empty host binds the wildcard address.  Each address
 * returned by getaddrinfo() is tried in order until one binds.  Every
 * failure is fatal; on success the listening descriptor is returned.
 */
int
create_server_socket(char *host, char *port)
{
        struct addrinfo hints, *res, *ai;
        const char *node, *shown;
        int error, fd = -1, option_value = 1, saved_errno = 0;

        memset(&hints, 0, sizeof(hints));
        hints.ai_family = PF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;
        /* Only honoured when node is NULL: requests the wildcard address. */
        hints.ai_flags = AI_PASSIVE;
        node = (host != NULL && host[0] != '\0') ? host : NULL;
        shown = (node != NULL) ? node : "*";

        error = getaddrinfo(node, port, &hints, &res);
        if (error != 0)
                errx(1, "\"%s\" - %s", shown, gai_strerror(error));

        for (ai = res; ai != NULL; ai = ai->ai_next) {
                /*
                 * Use the family getaddrinfo() picked: an AF_INET6 result
                 * cannot be bound to an AF_INET socket.
                 */
                fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
                if (fd == -1) {
                        saved_errno = errno;
                        continue;
                }
                /* SO_REUSEADDR only influences bind() if set before it. */
                if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
                    &option_value, sizeof(option_value)) == -1)
                        err(1, "could not set socket options");
                if (bind(fd, ai->ai_addr, ai->ai_addrlen) == 0)
                        break;
                saved_errno = errno;
                close(fd);
                fd = -1;
        }
        freeaddrinfo(res);

        if (fd == -1)
                errx(1, "could not bind to %s - %s", shown,
                    strerror(saved_errno));
        if (listen(fd, MAX_BACKLOG) == -1)
                err(1, "could not listen on server socket");
        return fd;
}

/*
 * Detach into the background with daemon(3) unless -d (debug_flag) was
 * given.  daemon(1, 1) keeps the current directory and leaves the
 * standard descriptors open.  A failure terminates the program.
 */
void
daemonise(void)
{
        if (debug_flag != 0)
                return;         /* debugging: stay in the foreground */

        if (daemon(1, 1) != 0)
                errx(1, "daemon(): %s", strerror(errno));
}

/*
 * printf-style diagnostic output to stderr, emitted only when -d
 * (debug_flag) is set.  Callers supply their own trailing newline.
 */
void
debug(const char *fmt, ...)
{
        va_list args;

        if (debug_flag == 0)
                return;

        va_start(args, fmt);
        (void) vfprintf(stderr, fmt, args);
        va_end(args);
}

/*
 * Switch the real and effective uid to _uid (chosen with -u), then make
 * sure neither is still root.  Every failure is fatal.  On a failed
 * setuid()/seteuid() the system error is included in the message.
 * Only the uid is changed here; group IDs are left untouched.
 */
void
drop_privs(void)
{
        if (setuid(_uid) == -1)
                err(1, "could not setuid()");
        if (seteuid(_uid) == -1)
                err(1, "could not seteuid()");
        /* Defence in depth: refuse to carry on with any root uid left. */
        if (getuid() == 0 || geteuid() == 0)
                errx(1, "do not run this program as root!");
}

void
sigchild(int sig)
{
        pid_t pid;
        int stat;
        pid = wait(&stat);
        return;
}

void
sigterm(int sig)
{
        struct syslog_data sdata = SYSLOG_DATA_INIT;
        syslog_r(LOG_INFO, &sdata, "shutting down");
        _exit(0);
}

ssize_t
transfer(int input_fd, int output_fd, size_t len)
{
        unsigned char buf[BUFSIZE];
        ssize_t bytes;

        bytes = recv(output_fd, buf, len, 0);
        if (bytes == 0)
                return 0;
        if (bytes < 0)
                debug("transfer() error");

        bytes = send(input_fd, buf, bytes, 0);
        if (bytes < 0) {
                debug("transfer() error");
                return 0;
        }
        return bytes;
}

/* Print the command-line synopsis to stderr and exit with status 1. */
void
usage(void)
{
        (void) fputs(
            "usage: ipfreely [-d] [-b localhost:port] [-r remotehost:port]\n",
            stderr);
        (void) fputs(
            "                [-c chroot dir] [-u username]\n",
            stderr);
        exit(1);
}

int
main(int argc, char **argv)
{
        int             error, ch, listenfd, connfd;
        pid_t           childpid;
        struct passwd   *passwd;
        char            *token;
        socklen_t       len;
        struct addrinfo hints, *res;
        struct sockaddr_storage cliaddr;
        struct syslog_data sdata = SYSLOG_DATA_INIT;

        if (argc < 2)
                usage();

        while ((ch = getopt(argc, argv, "db:r:c:u:")) != -1) {
                switch (ch) {
                case 'd':
                        debug_flag = 1;
                        break;
                case 'b':
                        token = strsep(&optarg, ":");
                        if (token == NULL)
                                errx(1, "invalid local host specification");
                        strlcpy(local_host, token, sizeof(local_host));
                        token = strsep(&optarg, ":");
                        if (token == NULL)
                                errx(1, "invalid local port specification");
                        strlcpy(local_port, token, sizeof(local_port));
                        break;
                case 'r':
                        token = strsep(&optarg, ":");
                        if (token == NULL)
                                errx(1, "invalid remote host specification");
                        strlcpy(remote_host, token, sizeof(remote_host));
                        token = strsep(&optarg, ":");
                        if (token == NULL)
                                errx(1, "invalid remote port specification");
                        strlcpy(remote_port, token, sizeof(remote_port));
                        break;
                case 'c':
                        strlcpy(chroot_dir, optarg, sizeof(chroot_dir));
                        if (chdir(chroot_dir) != 0)
                                errx(1, "chroot directory error: %s",
                                    strerror(errno));
                        break;
                case 'u':
                        strlcpy(username, optarg, sizeof(username));
                        passwd = getpwnam(username);
                        if (passwd == NULL)
                                errx(1, "user does not exist");
                        _uid = passwd->pw_uid;
                        break;
                case '?':
                default:
                        usage();
                }
        }
        argc -= optind;
        argv += optind;
        if (local_port == 0 || remote_port == 0 || _uid == 0)
                usage();
        signal(SIGTERM, sigterm);
        signal(SIGCHLD, sigchild);
        listenfd = create_server_socket(local_host, local_port);
        memset(&hints, 0, sizeof(hints));
        hints.ai_family = PF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;
        /* Get an address structure for the remote host */
        error = 0;
        error = getaddrinfo(remote_host, remote_port, &hints, &res);
        if (error != 0)
                errx(1, "could not getaddrinfo() %s", gai_strerror(error));
        chroot_jail();
        drop_privs();
        daemonise();
        setproctitle("parent listening on port %s", local_port);
        syslog_r(LOG_INFO, &sdata, "listening on port %s", local_port);
        len = sizeof(cliaddr);
        while(1) {
                connfd = accept(listenfd, (struct sockaddr *) &cliaddr,
                    &len);
                if ((childpid = fork()) == 0) {
                        int outfd, nfds;
                        char hbuf[NI_MAXHOST];
                        int done = 0, option_value = 1, bytes = 0;
                        struct pollfd pfd[2];
                        error = getnameinfo((struct sockaddr *) &cliaddr,
                            cliaddr.ss_len, hbuf, sizeof(hbuf), NULL,
                            0, NI_NUMERICHOST);
                        if (error != 0) {
                                strlcpy(hbuf, "unknown", sizeof(hbuf));
                                debug("connection from % (%s)\n",
                                    hbuf, gai_strerror(error));
                                syslog_r(LOG_INFO, &sdata, "connection from %s (%s)",
                                    hbuf, gai_strerror(error));
                        }
                        else {
                                debug("connection from %s\n", hbuf);
                                syslog_r(LOG_INFO, &sdata, "connection from %s", hbuf);
                        }
                        setproctitle("child [%s]", hbuf);
                        debug("creating outfd\n");
                        if ((outfd = socket(res->ai_family, res->ai_socktype,
                            res->ai_protocol)) == -1)
                                errx(1, "could not create socket");
                        error = setsockopt(outfd, SOL_SOCKET, SO_REUSEADDR,
                                    (char *) &option_value, sizeof(option_value));
                        if (error == -1)
                                errx(1, "could not set socket options");
                        debug("connecting outfd (%d)\n", outfd);
                        if (connect(outfd, res->ai_addr, res->ai_addrlen) == -1)
                                errx(1, "could not connect() %s",
                                strerror(error));
                        pfd[0].fd = connfd;
                        pfd[0].events = POLLIN;
                        pfd[1].fd = outfd;
                        pfd[1].events = POLLIN;
                        while(done == 0) {
                                nfds = poll(pfd, 2, -1);
                                if (pfd[0].revents & POLLIN) {
                                        bytes = transfer(pfd[1].fd, pfd[0].fd, BUFSIZE);
                                }
                                if (pfd[1].revents & POLLIN) {
                                        bytes = transfer(pfd[0].fd, pfd[1].fd, BUFSIZE);
                                }
                                if (bytes == 0)
                                        done = 1;
                        }
                        debug("shutting down connfd...\n");
                        shutdown(connfd, SHUT_WR);
                        close(connfd);
                        debug("shutting down outfd...\n");
                        shutdown(outfd, SHUT_WR);
                        close(outfd);
                        debug("exiting...\n");
                        syslog_r(LOG_INFO, &sdata, "disconnect from %s", hbuf);
                        exit(0);
                }
                if (childpid == -1)
                        errx(1, "could not fork() %s", strerror(errno));
                close(connfd);
        }
        return 0;
}
