/* $Id: rules.c,v 1.5 2003/10/06 08:49:35 dhartmei Exp $ */

/*
 * Copyright (c) 2003 Daniel Hartmeier
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    - Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *    - Redistributions in binary form must reproduce the above
 *      copyright notice, this list of conditions and the following
 *      disclaimer in the documentation and/or other materials provided
 *      with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

static const char rcsid[] = "$Id: rules.c,v 1.5 2003/10/06 08:49:35 dhartmei Exp $";

#include <errno.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <regex.h>
#include <unistd.h>
#include <libmilter/mfapi.h>

#include "rules.h"

/* logging and fatal-error hooks; declared extern, defined elsewhere */
extern void log(int, const char *, ...);
extern void die(const char *);

/* texts used when a reject/tempfail command carries no message */
#define	DEFAULT_REJECT_MSG	"Command refused"
#define DEFAULT_TEMPFAIL_MSG	"Please try again later"

/*
 * Parser state: the action and message copied into every rule parsed
 * after the most recent discard/reject/tempfail command.  parse_file()
 * resets them to SMFIS_REJECT / DEFAULT_REJECT_MSG before each file.
 */
static int		 action = SMFIS_REJECT;
static char		 message[128] = DEFAULT_REJECT_MSG;
/*
 * Initialized by filter_init(); taken while parsing (parse_file), freeing
 * (free_rules) and matching (match_rules) rule lists.
 */
static pthread_mutex_t	 mutex;

/*
 * REG_BASIC is a BSD extension; POSIX regcomp() already selects basic
 * syntax when REG_EXTENDED is absent, so 0 is an equivalent fallback.
 */
#ifndef	REG_BASIC
#define	REG_BASIC 0
#endif

#ifdef LINUX
/*
 * Replacements for the BSD strlcpy(3)/strlcat(3), which glibc has lacked.
 * Semantics follow the BSD manual: size is the full size of dst, the
 * result is NUL-terminated whenever size > 0, and the return value is the
 * length of the string that was attempted (so ret >= size means the
 * result was truncated).
 */
size_t
strlcpy(char *dst, const char *src, size_t size)
{
	size_t len = strlen(src);

	if (size > 0) {
		size_t n = len < size - 1 ? len : size - 1;

		memcpy(dst, src, n);
		dst[n] = 0;
	}
	return (len);
}

size_t
strlcat(char *dst, const char *src, size_t size)
{
	size_t dlen = 0;

	/* like BSD, never scan more than size bytes of dst */
	while (dlen < size && dst[dlen])
		dlen++;
	if (dlen == size)
		return (size + strlen(src));
	/*
	 * Append into the space that is actually left.  The old version
	 * passed the whole size to strncat(), which may append size bytes
	 * past the current end of dst and overflow the buffer.
	 */
	return (dlen + strlcpy(dst + dlen, src, size - dlen));
}
#endif

int
filter_init(void)
{
	/*
	 * Set up the mutex that guards the rule lists.  Any non-zero
	 * pthread error is reported to the caller as 1, success as 0.
	 */
	return (pthread_mutex_init(&mutex, NULL) != 0);
}

static void
mutex_lock(void)
{
	if (pthread_mutex_lock(&mutex))
		die("pthread_mutex_lock");
}

static void
mutex_unlock(void)
{
	if (pthread_mutex_unlock(&mutex))
		die("pthread_mutex_unlock");
}

/*
 * Extract the next blank-delimited word from *s into d, copying at most
 * size - 1 characters and NUL-terminating the result.
 *
 * Leading spaces and tabs are skipped.  When nothing but blanks remains,
 * *s is left at the end of the string and 1 is returned.  Otherwise the
 * word is copied, *s is left just past the copied characters and 0 is
 * returned; a word too long for d is truncated and its remainder stays
 * in *s.
 */
static int
parse_cmd(const char **s, char *d, size_t size)
{
	const char *p = *s + strspn(*s, " \t");
	size_t n;

	*s = p;
	if (*p == '\0')
		return (1);
	for (n = 0; p[n] != '\0' && p[n] != ' ' && p[n] != '\t' &&
	    n + 1 < size; n++)
		d[n] = p[n];
	d[n] = '\0';
	*s = p + n;
	return (0);
}

/*
 * Copy the rest of *s, after any leading spaces/tabs, into d (at most
 * size - 1 characters, NUL-terminated).  Embedded and trailing blanks
 * are kept.
 *
 * Returns 1, with *s at the end of the string, if only blanks remain;
 * otherwise returns 0 with *s just past the copied characters.
 */
static int
parse_str(const char **s, char *d, size_t size)
{
	const char *p = *s + strspn(*s, " \t");
	size_t n = 0;

	*s = p;
	if (*p == '\0')
		return (1);
	while (p[n] != '\0' && n + 1 < size) {
		d[n] = p[n];
		n++;
	}
	d[n] = '\0';
	*s = p + n;
	return (0);
}

/*
 * Parse one delimited regular-expression argument from *s into *a.
 *
 * The first non-blank character is the delimiter.  The text up to its
 * next occurrence is the expression.  Any run of the flag letters 'e'
 * (REG_EXTENDED), 'i' (REG_ICASE) and 'n' (negate the match) may follow
 * the closing delimiter.  A non-empty expression is compiled into a->re.
 * An empty expression is not compiled and may not carry flags.
 *
 * On success, a->s holds a malloc'ed copy of the expression, *s points
 * just past the flags, and 0 is returned.  On failure, 1 is returned and
 * no memory allocated here is left behind (a->s is reset to NULL if it
 * had been set).
 */
static int
parse_arg(const char **s, struct argument *a)
{
	char delim;
	const char *start;
	size_t len;
	int r, cflags;

	while (**s == ' ' || **s == '\t')
		(*s)++;
	if (!**s)
		return (1);
	delim = *(*s)++;
	start = *s;
	while (**s && **s != delim)
		(*s)++;
	if (**s != delim)
		return (1);	/* unterminated expression */
	len = (size_t)(*s - start);
	if ((a->s = malloc(len + 1)) == NULL) {
		log(LOG_ERR, "parse_arg: malloc: %s", strerror(errno));
		return (1);
	}
	memcpy(a->s, start, len);
	a->s[len] = 0;
	(*s)++;		/* skip the closing delimiter */
	while (**s == 'e' || **s == 'i' || **s == 'n') {
		switch (**s) {
		case 'e':
			a->e = 1;
			break;
		case 'i':
			a->i = 1;
			break;
		case 'n':
			a->n = 1;
			break;
		}
		(*s)++;
	}
	if (!a->s[0]) {
		if (a->e || a->i || a->n) {
			log(LOG_ERR, "parse_arg: empty expression cannot "
			    "have flags");
			/* this path used to leak the copy made above */
			free(a->s);
			a->s = NULL;
			return (1);
		}
		return (0);
	}
	cflags = a->e ? REG_EXTENDED : REG_BASIC;
	if (a->i)
		cflags |= REG_ICASE;
	if ((r = regcomp(&a->re, a->s, cflags)) != 0) {
		char err[8192];

		regerror(r, &a->re, err, sizeof(err));
		log(LOG_ERR, "parse_arg: regcomp: %s: %s", a->s, err);
		free(a->s);
		a->s = NULL;
		return (1);
	}
	return (0);
}

static void
log_rule(const char *prefix, const char *type, const struct rule *rule)
{
	char msg[8192];
	int i;

	snprintf(msg, sizeof(msg), "%s: %s", prefix, type);
	for (i = 0; i < 2; ++i)
		if (rule->arg[i].s != NULL) {
			strlcat(msg, " /", sizeof(msg));
			strlcat(msg, rule->arg[i].s, sizeof(msg));
			strlcat(msg, "/", sizeof(msg));
			if (rule->arg[i].e)
				strlcat(msg, "e", sizeof(msg));
			if (rule->arg[i].i)
				strlcat(msg, "i", sizeof(msg));
			if (rule->arg[i].n)
				strlcat(msg, "n", sizeof(msg));
		}
	log(LOG_DEBUG, "%s", msg);
}

/*
 * Parse one configuration line into ruleset.
 *
 * A line is a sequence of blank-separated commands:
 *   - a word starting with '#' turns the rest of the line into a comment;
 *   - "discard", "reject [message]" and "tempfail [message]" set the
 *     action (and message) copied into every rule parsed after them.
 *     The optional message consumes the rest of the line;
 *   - "connect" and "header" add a rule with two delimited expressions;
 *     "helo", "envfrom", "envrcpt" and "body" add a rule with one (see
 *     parse_arg()).  "connect" and "helo" are refused while the action
 *     is discard.
 * New rules are prepended to the matching list in ruleset.  parse_file()
 * calls this with the mutex held, since it updates the global
 * action/message state and the rule lists.
 *
 * Returns 0 on success, 1 on a syntax or allocation error.
 */
static int
parse_line(const char *s, struct ruleset *ruleset)
{
	char cmd[8192];

	if (s[0] == 0) {
		log(LOG_DEBUG, "parse_line: ");
		return (0);
	}
	while (!parse_cmd(&s, cmd, sizeof(cmd))) {
		/*
		 * Scoped to this command.  Previously these survived the
		 * iteration, so an action keyword following a rule on the
		 * same line ("envfrom /x/ discard") tried to build a second
		 * rule and failed.
		 */
		struct rule **rules = NULL;
		int args = 0;

		if (cmd[0] == '#') {
			log(LOG_DEBUG, "parse_line: %s%s", cmd, s);
			return (0);
		} else if (!strcasecmp(cmd, "discard")) {
			action = SMFIS_DISCARD;
			message[0] = 0;
			log(LOG_DEBUG, "parse_line: discard");
		} else if (!strcasecmp(cmd, "reject")) {
			action = SMFIS_REJECT;
			if (parse_str(&s, message, sizeof(message)))
				strlcpy(message, DEFAULT_REJECT_MSG,
				    sizeof(message));
			log(LOG_DEBUG, "parse_line: reject %s", message);
		} else if (!strcasecmp(cmd, "tempfail")) {
			action = SMFIS_TEMPFAIL;
			if (parse_str(&s, message, sizeof(message)))
				strlcpy(message, DEFAULT_TEMPFAIL_MSG,
				    sizeof(message));
			log(LOG_DEBUG, "parse_line: tempfail %s", message);
		} else if (!strcasecmp(cmd, "connect")) {
			if (action == SMFIS_DISCARD) {
				log(LOG_ERR, "connect rule should not discard");
				return (1);
			}
			rules = &ruleset->connect;
			args = 2;
		} else if (!strcasecmp(cmd, "helo")) {
			if (action == SMFIS_DISCARD) {
				log(LOG_ERR, "helo rule should not discard");
				return (1);
			}
			rules = &ruleset->helo;
			args = 1;
		} else if (!strcasecmp(cmd, "envfrom")) {
			rules = &ruleset->envfrom;
			args = 1;
		} else if (!strcasecmp(cmd, "envrcpt")) {
			rules = &ruleset->envrcpt;
			args = 1;
		} else if (!strcasecmp(cmd, "header")) {
			rules = &ruleset->header;
			args = 2;
		} else if (!strcasecmp(cmd, "body")) {
			rules = &ruleset->body;
			args = 1;
		} else {
			log(LOG_ERR, "parse_line: invalid command '%s'", cmd);
			return (1);
		}
		if (rules != NULL) {
			struct rule *rule;
			int i, j;

			if ((rule = malloc(sizeof(*rule))) == NULL) {
				log(LOG_ERR, "parse_line: malloc: %s",
				    strerror(errno));
				return (1);
			}
			memset(rule, 0, sizeof(*rule));
			rule->action = action;
			strlcpy(rule->message, message, sizeof(rule->message));
			for (i = 0; i < args; ++i) {
				if (!parse_arg(&s, &rule->arg[i]))
					continue;
				log(LOG_ERR, "parse_line: invalid "
				    "argument");
				/*
				 * The rule is not linked into any list yet.
				 * Release it together with whatever arguments
				 * were already parsed, instead of leaking them.
				 * An empty expression was never compiled.
				 */
				for (j = 0; j < 2; ++j) {
					if (rule->arg[j].s == NULL)
						continue;
					if (rule->arg[j].s[0])
						regfree(&rule->arg[j].re);
					free(rule->arg[j].s);
				}
				free(rule);
				return (1);
			}
			rule->next = *rules;
			*rules = rule;
			log_rule("parse_line", cmd, rule);
		}
	}
	return (0);
}

int
parse_file(const char *name, struct ruleset *ruleset)
{
	FILE *file;
	char line[8192];
	unsigned count = 0;

	if ((file = fopen(name, "r")) == NULL) {
		log(LOG_ERR, "parse_file: fopen: %s: %s", name, strerror(errno));
		return (1);
	}
	action = SMFIS_REJECT;
	strlcpy(message, DEFAULT_REJECT_MSG, sizeof(message));
	mutex_lock();
	while (fgets(line, sizeof(line), file) != NULL) {
		int l = strlen(line);

		if (l > 0 && line[l - 1] == '\n')
			line[l - 1] = 0;
		count++;
		if (parse_line(line, ruleset)) {
			mutex_unlock();
			fclose(file);
			log(LOG_ERR, "parse_file: error on line %u", count);
			return (1);
		}
	}
	mutex_unlock();
	fclose(file);
	return (0);
}

/*
 * Free every rule on the list *rules, including each argument's string
 * and compiled expression, and leave the list empty.
 *
 * The list head is read and cleared while holding the mutex.  Reading it
 * before locking (as before) allowed a rule prepended by a concurrent
 * parse_file() to slip in between and then be leaked when *rules was
 * reset to NULL.
 */
static void
free_rules(struct rule **rules)
{
	struct rule *rule, *next;
	int i;

	mutex_lock();
	for (rule = *rules; rule != NULL; rule = next) {
		next = rule->next;
		for (i = 0; i < 2; ++i) {
			if (rule->arg[i].s == NULL)
				continue;
			/* an empty expression was never passed to regcomp() */
			if (rule->arg[i].s[0])
				regfree(&rule->arg[i].re);
			free(rule->arg[i].s);
		}
		free(rule);
	}
	*rules = NULL;
	mutex_unlock();
}

/*
 * Release every rule list in ruleset (connect, helo, envfrom, envrcpt,
 * header, body, in that order), leaving each list empty.
 */
void
free_ruleset(struct ruleset *ruleset)
{
	struct rule **lists[] = {
		&ruleset->connect,
		&ruleset->helo,
		&ruleset->envfrom,
		&ruleset->envrcpt,
		&ruleset->header,
		&ruleset->body
	};
	size_t k;

	for (k = 0; k < sizeof(lists) / sizeof(lists[0]); k++)
		free_rules(lists[k]);
}

/* returns 0 on match, 1 on mismatch, -1 on error */
static int
match_rule(const struct rule *rule, int i, const char *s)
{
	int r;

	if (i < 0 || i > 1) {
		log(LOG_ERR, "match_rule: invalid argument");
		return (-1);
	}
	if (rule->arg[i].s == NULL) {
		log(LOG_ERR, "match_rule: invalid rule");
		return (-1);
	}
	if (!rule->arg[i].s[0])
		return (0);
	r = regexec(&rule->arg[i].re, s, 0, NULL, 0);
	if (r && r != REG_NOMATCH) {
		char err[8192];

		regerror(r, &rule->arg[i].re, err, sizeof(err));
		log(LOG_ERR, "match_rule: regexec: %s: %s", rule->arg[i].s, err);
		return (-1);
	}
	return ((r != REG_NOMATCH) == rule->arg[i].n);
}

/*
 * Walk the list rules, holding the mutex, and return the first rule whose
 * argument 0 matches s and, when t is non-NULL, whose argument 1 matches t.
 * Returns NULL if rules or s is NULL, if no rule matches, or if matching
 * reports an error.
 */
const struct rule *
match_rules(const struct rule *rules, const char *s, const char *t)
{
	const struct rule *rule, *found = NULL;

	if (rules == NULL || s == NULL)
		return (NULL);
	mutex_lock();
	for (rule = rules; rule != NULL; rule = rule->next) {
		int r = match_rule(rule, 0, s);

		/* only consult the second argument if the first matched */
		if (r == 0 && t != NULL)
			r = match_rule(rule, 1, t);
		if (r < 0)
			break;
		if (r == 0) {
			found = rule;
			break;
		}
	}
	mutex_unlock();
	return (found);
}
