/*
 * jabberd - Jabber Open Source Server
 * Copyright (c) 2002 Jeremie Miller, Thomas Muldowney,
 *                    Ryan Eatmon, Robert Norris
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA02111-1307USA
 */

#include "s2s.h"

/* Flags set from signal handlers and polled by the main loop.
 * The C standard only defines the behaviour of such flags when they are
 * volatile sig_atomic_t; without volatile the compiler is free to keep a
 * stale copy across loop iterations and never notice the signal. */
static volatile sig_atomic_t s2s_shutdown = 0;
/* NOTE(review): also written by _s2s_signal, but left non-volatile because it
 * is non-static and presumably declared extern in s2s.h - changing its
 * qualifiers here alone would conflict with that declaration; fix both together. */
sig_atomic_t s2s_lost_router = 0;
static volatile sig_atomic_t s2s_logrotate = 0;

/** SIGINT/SIGTERM handler: ask the main loop to exit.
 *
 *  Only flags are touched, which keeps the handler async-signal-safe. The
 *  lost-router flag is cleared so the main loop does not start a reconnect
 *  attempt while it is on its way out. */
static void _s2s_signal(int signum) {
    (void) signum;

    s2s_lost_router = 0;
    s2s_shutdown = 1;
}

/** SIGHUP handler: request that the main loop close and reopen the log. */
static void _s2s_signal_hup(int signum) {
    (void) signum;

    s2s_logrotate = 1;
}

/** store the process id */
/** store the process id
 *
 *  Writes this process's pid, in decimal, to the file named by the "pidfile"
 *  config key; does nothing when that key is absent. Any failure is logged
 *  at LOG_ERR and otherwise ignored, so startup continues regardless.
 */
static void _s2s_pidfile(s2s_t s2s) {
    char *pidfile;
    FILE *f;
    pid_t pid;

    pidfile = config_get_one(s2s->config, "pidfile", 0);
    if(pidfile == NULL)
        return;

    pid = getpid();

    if((f = fopen(pidfile, "w+")) == NULL) {
        log_write(s2s->log, LOG_ERR, "couldn't open %s for writing: %s", pidfile, strerror(errno));
        return;
    }

    /* pid_t is not guaranteed to be int; cast so %d is well-defined */
    if(fprintf(f, "%d", (int) pid) < 0) {
        log_write(s2s->log, LOG_ERR, "couldn't write to %s: %s", pidfile, strerror(errno));
        fclose(f);      /* don't leak the stream on the error path */
        return;
    }

    /* fclose flushes the buffer, so a failed write may only show up here */
    if(fclose(f) != 0) {
        log_write(s2s->log, LOG_ERR, "couldn't write to %s: %s", pidfile, strerror(errno));
        return;
    }

    log_write(s2s->log, LOG_INFO, "process id is %d, written to %s", (int) pid, pidfile);
}

/** pull values out of the config file */
/** pull values out of the config file
 *
 *  Fills the s2s struct from s2s->config, substituting a built-in default
 *  for each setting that is absent. Most string fields are set to the
 *  pointer config_get_one() returns (not copied). local_secret is the
 *  exception: it is always heap-allocated here (or NULL if allocation
 *  fails), because main() passes it to free() at shutdown.
 */
static void _s2s_config_expand(s2s_t s2s) {
    char *str, secret[41];
    int i, r;

    s2s->id = config_get_one(s2s->config, "id", 0);
    if(s2s->id == NULL)
        s2s->id = "s2s";

    s2s->router_ip = config_get_one(s2s->config, "router.ip", 0);
    if(s2s->router_ip == NULL)
        s2s->router_ip = "127.0.0.1";

    s2s->router_port = j_atoi(config_get_one(s2s->config, "router.port", 0), 5347);

    s2s->router_user = config_get_one(s2s->config, "router.user", 0);
    if(s2s->router_user == NULL)
        s2s->router_user = "jabberd";
    s2s->router_pass = config_get_one(s2s->config, "router.pass", 0);
    if(s2s->router_pass == NULL)
        s2s->router_pass = "secret";

    s2s->router_pemfile = config_get_one(s2s->config, "router.pemfile", 0);

    s2s->retry_init = j_atoi(config_get_one(s2s->config, "router.retry.init", 0), 3);
    s2s->retry_lost = j_atoi(config_get_one(s2s->config, "router.retry.lost", 0), 3);
    /* a zero/negative sleep would turn the reconnect loop into a busy spin */
    if((s2s->retry_sleep = j_atoi(config_get_one(s2s->config, "router.retry.sleep", 0), 2)) < 1)
        s2s->retry_sleep = 1;

    s2s->router_default = config_count(s2s->config, "router.non-default") ? 0 : 1;

    s2s->log_type = log_STDOUT;
    if(config_get(s2s->config, "log") != NULL) {
        if((str = config_get_attr(s2s->config, "log", 0, "type")) != NULL) {
            if(strcmp(str, "file") == 0)
                s2s->log_type = log_FILE;
            else if(strcmp(str, "syslog") == 0)
                s2s->log_type = log_SYSLOG;
        }
    }

    /* log_ident doubles as the syslog ident or the log file name */
    if(s2s->log_type == log_SYSLOG) {
        s2s->log_facility = config_get_one(s2s->config, "log.facility", 0);
        s2s->log_ident = config_get_one(s2s->config, "log.ident", 0);
        if(s2s->log_ident == NULL)
            s2s->log_ident = "jabberd/s2s";
    } else if(s2s->log_type == log_FILE)
        s2s->log_ident = config_get_one(s2s->config, "log.file", 0);

    s2s->local_ip = config_get_one(s2s->config, "local.ip", 0);
    if(s2s->local_ip == NULL)
        s2s->local_ip = "0.0.0.0";

    s2s->local_port = j_atoi(config_get_one(s2s->config, "local.port", 0), 0);

    s2s->local_resolver = config_get_one(s2s->config, "local.resolver", 0);
    if(s2s->local_resolver == NULL)
        s2s->local_resolver = "resolver";

    /* Secret: the configured one, else 40 random characters from [0-9a-z].
     * NOTE(review): rand() is not a cryptographic RNG and main() seeds it
     * with time(NULL), so a generated secret is guessable; deployments that
     * care should set local.secret explicitly. */
    str = config_get_one(s2s->config, "local.secret", 0);
    if(str != NULL)
        s2s->local_secret = strdup(str);
    else {
        for(i = 0; i < 40; i++) {
            /* scale into [0, 35]; dividing by RAND_MAX alone lets
             * rand() == RAND_MAX produce 36, i.e. the character '{' */
            r = (int) (36.0 * rand() / (RAND_MAX + 1.0));
            secret[i] = (r < 10) ? ('0' + r) : ('a' + (r - 10));
        }
        secret[40] = '\0';

        s2s->local_secret = strdup(secret);
    }

    /* the fallback must be heap memory too, since main() free()s it */
    if(s2s->local_secret == NULL)
        s2s->local_secret = strdup("secret");

    s2s->local_pemfile = config_get_one(s2s->config, "local.pemfile", 0);
    if(s2s->local_pemfile != NULL)
        log_debug(ZONE,"loaded local pemfile for peer s2s connections");

    s2s->check_interval = j_atoi(config_get_one(s2s->config, "check.interval", 0), 0);
    s2s->check_queue = j_atoi(config_get_one(s2s->config, "check.queue", 0), 0);
    s2s->check_invalid = j_atoi(config_get_one(s2s->config, "check.invalid", 0), 0);
    s2s->check_keepalive = j_atoi(config_get_one(s2s->config, "check.keepalive", 0), 0);
}

/** start a connection to the router
 *
 *  On success an sx client stream is created on the fd returned by
 *  mio_connect() and initialised for protocol version "1.0"; returns 0.
 *  On failure returns 1, and raises s2s_lost_router (which makes the main
 *  loop retry) only when the error was ECONNREFUSED.
 */
static int _s2s_router_connect(s2s_t s2s) {
    int err;

    log_write(s2s->log, LOG_NOTICE, "attempting connection to router at %s, port=%d", s2s->router_ip, s2s->router_port);

    s2s->fd = mio_connect(s2s->mio, s2s->router_port, s2s->router_ip, s2s_router_mio_callback, (void *) s2s);
    if(s2s->fd >= 0) {
        s2s->router = sx_new(s2s->sx_env, s2s->fd, s2s_router_sx_callback, (void *) s2s);
        sx_client_init(s2s->router, 0, NULL, NULL, NULL, "1.0");
        return 0;
    }

    /* snapshot errno before anything else can disturb it */
    err = errno;
    if(err == ECONNREFUSED)
        s2s_lost_router = 1;

    log_write(s2s->log, LOG_NOTICE, "connection attempt to router failed: %s (%d)", strerror(err), err);

    return 1;
}

static void _s2s_time_checks(s2s_t s2s) {
    conn_t conn;
    time_t now;
    char *domain, ipport[INET6_ADDRSTRLEN + 17], *key;
    jqueue_t q;
    dnscache_t dns;
    pkt_t pkt;
    conn_state_t state;

    now = time(NULL);

    /* queue expiry */
    if(s2s->check_queue > 0) {
        if(xhash_iter_first(s2s->outq))
            do {
                xhash_iter_get(s2s->outq, (const char **) &domain, (void **) &q);

                log_debug(ZONE, "running time checks for %s", domain);

                /* dns lookup timeout check first */
                dns = xhash_get(s2s->dnscache, domain);
                if(dns == NULL)
                    continue;

                if(dns->pending) {
                    log_debug(ZONE, "dns lookup pending for %s", domain);
                    if(now > dns->init_time + s2s->check_queue) {
                        log_debug(ZONE, "dns lookup expired for %s, bouncing packets in queue", domain);
                        /* bounce queue */
                        while((pkt = jqueue_pull(q)) != NULL) {
                            if(pkt->nad->ecur > 1 && NAD_NURI_L(pkt->nad, NAD_ENS(pkt->nad, 1)) == strlen(uri_CLIENT) && strncmp(NAD_NURI(pkt->nad, NAD_ENS(pkt->nad, 1)), uri_CLIENT, strlen(uri_CLIENT)) == 0)
                                sx_nad_write(s2s->router, stanza_tofrom(stanza_tofrom(stanza_error(pkt->nad, 1, stanza_err_REMOTE_SERVER_TIMEOUT), 1), 0));
                            else
                                nad_free(pkt->nad);

                            jid_free(pkt->to);
                            jid_free(pkt->from);
                            free(pkt);
                        }

                        /* expire pending dns entry */
                        xhash_zap(s2s->dnscache, dns->name);
                        free(dns);
                    }

                    continue;
                }

                /* generate the ip/port pair */
                snprintf(ipport, INET6_ADDRSTRLEN + 16, "%s/%d", dns->ip, dns->port);

                /* get the conn */
                conn = xhash_get(s2s->out, ipport);
                if(conn == NULL) {
                    /* no pending conn? perhaps it failed? */
                    log_debug(ZONE, "no pending connection for %s, bouncing queue", domain);

                    /* bounce queue */
                    while((pkt = jqueue_pull(q)) != NULL) {
                        if(pkt->nad->ecur > 1 && NAD_NURI_L(pkt->nad, NAD_ENS(pkt->nad, 1)) == strlen(uri_CLIENT) && strncmp(NAD_NURI(pkt->nad, NAD_ENS(pkt->nad, 1)), uri_CLIENT, strlen(uri_CLIENT)) == 0)
                            sx_nad_write(s2s->router, stanza_tofrom(stanza_tofrom(stanza_error(pkt->nad, 1, stanza_err_REMOTE_SERVER_TIMEOUT), 1), 0));
                        else
                            nad_free(pkt->nad);

                        jid_free(pkt->to);
                        jid_free(pkt->from);
                        free(pkt);
                    }

                    continue;
                }

                /* connect timeout check */
                if(!conn->online && now > conn->init_time + s2s->check_queue) {
                    log_debug(ZONE, "connection to %s is not online yet, bouncing queue", domain);
                    /* !!! kill conn if necessary */
                    /* bounce queue */
                    while((pkt = jqueue_pull(q)) != NULL) {
                        if(pkt->nad->ecur > 1 && NAD_NURI_L(pkt->nad, NAD_ENS(pkt->nad, 1)) == strlen(uri_CLIENT) && strncmp(NAD_NURI(pkt->nad, NAD_ENS(pkt->nad, 1)), uri_CLIENT, strlen(uri_CLIENT)) == 0)
                            sx_nad_write(s2s->router, stanza_tofrom(stanza_tofrom(stanza_error(pkt->nad, 1, stanza_err_REMOTE_SERVER_TIMEOUT), 1), 0));
                        else
                            nad_free(pkt->nad);

                        jid_free(pkt->to);
                        jid_free(pkt->from);
                        free(pkt);
                    }
                }
            } while(xhash_iter_next(s2s->outq));
    }

    /* invalid expiry */
    if(s2s->check_invalid > 0 && now > s2s->last_invalid_check + s2s->check_invalid) {
        if(xhash_iter_first(s2s->out))
            do {
                xhash_iter_get(s2s->out, (const char **) &key, (void *) &conn);
                log_debug(ZONE, "checking connection state for %s", key);
                if(xhash_iter_first(conn->states))
                    do {
                        xhash_iter_get(conn->states, NULL, (void *) &state);

                        /* drop invalid */
                        if(state == conn_INVALID) {
                            xhash_zap(conn->states, key);
                            log_debug(ZONE, "dropping invalid connection for conn key %s", key);
                        }

                    } while(xhash_iter_next(conn->states));
            } while(xhash_iter_next(s2s->out));
    }

    /* keepalives */
    if(xhash_iter_first(s2s->out))
        do {
            xhash_iter_get(s2s->out, NULL, (void **) &conn);

            if(s2s->check_keepalive > 0 && conn->last_activity > 0 && now > conn->last_activity + s2s->check_keepalive && conn->s->state >= state_STREAM) {
                log_debug(ZONE, "sending keepalive for %d", conn->fd);

                sx_raw_write(conn->s, " ", 1);
                
                mio_write(s2s->mio, conn->fd);
            }
        } while(xhash_iter_next(s2s->out));
}

int main(int argc, char **argv) {
    s2s_t s2s;
    char *config_file;
    int optchar;
#ifdef POOL_DEBUG
    time_t pool_time = 0;
#endif

#ifdef HAVE_UMASK
    umask((mode_t) 0027);
#endif

    srand(time(NULL));

#ifdef HAVE_WINSOCK2_H
/* get winsock running */
	{
		WORD wVersionRequested;
		WSADATA wsaData;
		int err;
		
		wVersionRequested = MAKEWORD( 2, 2 );
		
		err = WSAStartup( wVersionRequested, &wsaData );
		if ( err != 0 ) {
            /* !!! tell user that we couldn't find a usable winsock dll */
			return 0;
		}
	}
#endif

    jabber_signal(SIGINT, _s2s_signal);
    jabber_signal(SIGTERM, _s2s_signal);
#ifdef SIGHUP
    jabber_signal(SIGHUP, _s2s_signal_hup);
#endif
#ifdef SIGPIPE
    jabber_signal(SIGPIPE, SIG_IGN);
#endif

    s2s = (s2s_t) malloc(sizeof(struct s2s_st));
    memset(s2s, 0, sizeof(struct s2s_st));

    /* load our config */
    s2s->config = config_new();

    config_file = CONFIG_DIR "/s2s.xml";

    /* cmdline parsing */
    while((optchar = getopt(argc, argv, "Dc:h?")) >= 0)
    {
        switch(optchar)
        {
            case 'c':
                config_file = optarg;
                break;
            case 'D':
#ifdef DEBUG
                set_debug_flag(1);
#else
                printf("WARN: Debugging not enabled.  Ignoring -D.\n");
#endif
                break;
            case 'h': case '?': default:
                fputs(
                    "s2s - jabberd server-to-server connector (" VERSION ")\n"
                    "Usage: s2s <options>\n"
                    "Options are:\n"
                    "   -c <config>     config file to use [default: " CONFIG_DIR "/s2s.xml]\n"
#ifdef DEBUG
                    "   -D              Show debug output\n"
#endif
                    ,
                    stdout);
                config_free(s2s->config);
                free(s2s);
                return 1;
        }
    }

    if(config_load(s2s->config, config_file) != 0) {
        fputs("s2s: couldn't load config, aborting\n", stderr);
        config_free(s2s->config);
        free(s2s);
        return 2;
    }

    _s2s_config_expand(s2s);

    s2s->log = log_new(s2s->log_type, s2s->log_ident, s2s->log_facility);
    log_write(s2s->log, LOG_NOTICE, "starting up");

    _s2s_pidfile(s2s);

    s2s->outq = xhash_new(401);
    s2s->out = xhash_new(401);
    s2s->in = xhash_new(401);
    s2s->dnscache = xhash_new(401);

    s2s->pc = prep_cache_new();

    s2s->dead = jqueue_new();

    s2s->sx_env = sx_env_new();

#ifdef HAVE_SSL
    /* get the ssl context up and running */
    if(s2s->local_pemfile != NULL) {
        s2s->sx_ssl = sx_env_plugin(s2s->sx_env, sx_ssl_init, s2s->local_pemfile, NULL);
        if(s2s->sx_ssl == NULL) {
            log_write(s2s->log, LOG_ERR, "failed to load local SSL pemfile, SSL will not be available to peers");
            s2s->local_pemfile = NULL;
        } else
            log_debug(ZONE, "loaded pemfile for SSL connections to peers");
    }

    /* try and get something online, so at least we can encrypt to the router */
    if(s2s->sx_ssl == NULL && s2s->router_pemfile != NULL) {
        s2s->sx_ssl = sx_env_plugin(s2s->sx_env, sx_ssl_init, s2s->router_pemfile, NULL);
        if(s2s->sx_ssl == NULL) {
            log_write(s2s->log, LOG_ERR, "failed to load router SSL pemfile, channel to router will not be SSL encrypted");
            s2s->router_pemfile = NULL;
        }
    }
#endif

    /* get sasl online */
    s2s->sx_sasl = sx_env_plugin(s2s->sx_env, sx_sasl_init, NULL, NULL, 0);
    if(s2s->sx_sasl == NULL) {
        log_write(s2s->log, LOG_ERR, "failed to initialise SASL context, aborting");
        exit(1);
    }
            
    s2s->sx_db = sx_env_plugin(s2s->sx_env, s2s_db_init);

    s2s->mio = mio_new(1024);

    s2s->retry_left = s2s->retry_init;
    _s2s_router_connect(s2s);

    while(!s2s_shutdown) {
        mio_run(s2s->mio, 5);

        if(s2s_logrotate) {
            log_write(s2s->log, LOG_NOTICE, "reopening log ...");
            log_free(s2s->log);
            s2s->log = log_new(s2s->log_type, s2s->log_ident, s2s->log_facility);
            log_write(s2s->log, LOG_NOTICE, "log started");

            s2s_logrotate = 0;
        }

        if(s2s_lost_router) {
            if(s2s->retry_left < 0) {
                log_write(s2s->log, LOG_NOTICE, "attempting reconnect");
                sleep(s2s->retry_sleep);
                s2s_lost_router = 0;
                _s2s_router_connect(s2s);
            }

            else if(s2s->retry_left == 0) {
                s2s_shutdown = 1;
            }

            else {
                log_write(s2s->log, LOG_NOTICE, "attempting reconnect (%d left)", s2s->retry_left);
                s2s->retry_left--;
                sleep(s2s->retry_sleep);
                s2s_lost_router = 0;
                _s2s_router_connect(s2s);
            }
        }
            
        /* cleanup dead sx_ts */
        while(jqueue_size(s2s->dead) > 0)
            sx_free((sx_t) jqueue_pull(s2s->dead));

        /* time checks */
        if(s2s->check_interval > 0 && time(NULL) >= s2s->next_check) {
            log_debug(ZONE, "running time checks");

            _s2s_time_checks(s2s);

            s2s->next_check = time(NULL) + s2s->check_interval;
            log_debug(ZONE, "next time check at %d", s2s->next_check);
        }

#ifdef POOL_DEBUG
        if(time(NULL) > pool_time + 60) {
            pool_stat(1);
            pool_time = time(NULL);
        }
#endif
    }

    log_write(s2s->log, LOG_NOTICE, "shutting down");

    /* !!! close conns */

    /* !!! walk & free resolve queues */

    xhash_free(s2s->outq);
    xhash_free(s2s->out);
    xhash_free(s2s->in);
    xhash_free(s2s->dnscache);

    prep_cache_free(s2s->pc);

    sx_free(s2s->router);

    sx_env_free(s2s->sx_env);

    mio_free(s2s->mio);

    log_free(s2s->log);

    config_free(s2s->config);

    free(s2s->local_secret);
    free(s2s);

#ifdef POOL_DEBUG
    pool_stat(1);
#endif

    return 0;
}
