/* 

                          Firewall Builder

                 Copyright (C) 2002 NetCitadel, LLC

  Author:  Vadim Kurland     vadim@vk.crocodile.org

  $Id: PolicyCompiler_pf_writers.cpp,v 1.9 2004/09/19 06:20:12 vkurland Exp $

  This program is free software which we release under the GNU General Public
  License. You may redistribute and/or modify this program under the terms
  of that license as published by the Free Software Foundation; either
  version 2 of the License, or (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
 
  To get a copy of the GNU General Public License, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

*/

#include "PolicyCompiler_pf.h"

#include "fwbuilder/AddressRange.h"
#include "fwbuilder/RuleElement.h"
#include "fwbuilder/IPService.h"
#include "fwbuilder/ICMPService.h"
#include "fwbuilder/TCPService.h"
#include "fwbuilder/UDPService.h"
#include "fwbuilder/CustomService.h"
#include "fwbuilder/Policy.h"
#include "fwbuilder/FWOptions.h"
#include "fwbuilder/FWObjectDatabase.h"
#include "fwbuilder/RuleElement.h"
#include "fwbuilder/Interface.h"
#include "fwbuilder/IPv4.h"

#include <iostream>
#if __GNUC__ > 3 || \
    (__GNUC__ == 3 && (__GNUC_MINOR__ > 2 || (__GNUC_MINOR__ == 2 ) ) ) || \
    _MSC_VER
#  include <streambuf>
#else
#  include <streambuf.h>
#endif
#include <iomanip>
#include <fstream>
#include <sstream>

#include <assert.h>

using namespace libfwbuilder;
using namespace fwcompiler;
using namespace std;





/**
 *-----------------------------------------------------------------------
 *                    Methods for printing
 */
/*
 * Emits the pf action keyword(s) for the rule, each followed by a space.
 *
 * Accept and Accounting both print "pass"; Deny prints "block".
 * Reject prints "block return-rst" when the first service is TCP, and
 * otherwise "block return-icmp ..." where the ICMP code is picked from the
 * text of the rule option "action_on_reject". Scrub prints "scrub"; any
 * other action is printed by its string name.
 */
void PolicyCompiler_pf::PrintRule::_printAction(PolicyRule *rule)
{
    FWOptions *ruleopt = rule->getOptionsObject();
    Service   *srv     = compiler->getFirstSrv(rule);
    assert(srv);

    switch (rule->getAction())
    {
    case PolicyRule::Accept:
    case PolicyRule::Accounting:
        compiler->output << "pass ";
        break;

    case PolicyRule::Deny:
        compiler->output << "block ";
        break;

    case PolicyRule::Reject:
        if (TCPService::isA(srv))
        {
            // TCP is always rejected with a RST, whatever action_on_reject says
            compiler->output << "block return-rst ";
        } else
        {
            string reject_how = ruleopt->getStr("action_on_reject");
            string icmp;

            if (reject_how.find("ICMP") == string::npos)
            {
                icmp = "return-icmp   ";
            } else
            {
                icmp = "return-icmp ";
                bool want_net  = (reject_how.find("net")  != string::npos);
                bool want_host = (reject_how.find("host") != string::npos);

                // "unreachable" variants: net->0, host->1, protocol->2, port->3
                if (reject_how.find("unreachable") != string::npos)
                {
                    if (want_net)  icmp += "( 0 ) ";
                    if (want_host) icmp += "( 1 ) ";
                    if (reject_how.find("protocol") != string::npos) icmp += "( 2 ) ";
                    if (reject_how.find("port")     != string::npos) icmp += "( 3 ) ";
                }
                // "prohibited" variants: net->9, host->10
                if (reject_how.find("prohibited") != string::npos)
                {
                    if (want_net)  icmp += "( 9 ) ";
                    if (want_host) icmp += "( 10 ) ";
                }
            }

            compiler->output << "block " << icmp;
        }
        break;

    case PolicyRule::Scrub:
        compiler->output << "scrub   ";
        break;

    default:
        compiler->output << rule->getActionAsString() << " ";
    }
}

/*
 * Emits "out " for outbound rules and "in  " for every other direction.
 * Both keywords are padded to the same four-character width.
 */
void PolicyCompiler_pf::PrintRule::_printDirection(PolicyRule *rule)
{
    const bool outbound = (rule->getDirection() == PolicyRule::Outbound);
    compiler->output << (outbound ? "out " : "in  ");
}

/*
 * Emits the " log " keyword when logging is enabled for the rule.
 */
void PolicyCompiler_pf::PrintRule::_printLogging(PolicyRule *rule)
{
    if (!rule->getLogging()) return;

    compiler->output << " log ";
}

/*
 * Emits a pf "label" clause built from the log prefix.
 * The rule's own "log_prefix" option takes precedence; when it is empty the
 * firewall-wide "log_prefix" is used. Nothing is printed if both are empty.
 */
void PolicyCompiler_pf::PrintRule::_printLabel(PolicyRule *rule)
{
    string prefix = rule->getOptionsObject()->getStr("log_prefix");
    if (prefix.empty())
        prefix = compiler->getCachedFwOpt()->getStr("log_prefix");

    if (prefix.empty()) return;

    compiler->output << " label " << _printLogPrefix(rule, prefix) << " ";
}

/*
 * Expands logging macros in 'prefix' for the given rule and returns the
 * result wrapped in double quotes, followed by a single space.
 *
 * Macros (only the first occurrence of each is expanded, in this order):
 *   %N - rule position
 *   %A - action: ACCEPT, DROP, REJECT or RETURN (empty for other actions)
 *   %I - interface name, or "global" when the rule has no interface
 *   %C - value of the rule attribute "ipt_chain"
 * When rule is NULL the prefix is quoted without any expansion.
 */
string PolicyCompiler_pf::PrintRule::_printLogPrefix(PolicyRule *rule,
                                                     const string &prefix)
{
    string res = prefix;

    if (rule != NULL)
    {
        string::size_type pos;

        pos = res.find("%N");
        if (pos != string::npos)
        {
            std::ostringstream num;
            num << rule->getPosition();
            res.replace(pos, 2, num.str());
        }

        pos = res.find("%A");
        if (pos != string::npos)
        {
            const char *action_name = "";
            switch (rule->getAction())
            {
            case PolicyRule::Accept:  action_name = "ACCEPT"; break;
            case PolicyRule::Deny:    action_name = "DROP";   break;
            case PolicyRule::Reject:  action_name = "REJECT"; break;
            case PolicyRule::Return:  action_name = "RETURN"; break;
            default: break;
            }
            res.replace(pos, 2, action_name);
        }

        pos = res.find("%I");
        if (pos != string::npos)
        {
            string iface = rule->getInterfaceStr();
            res.replace(pos, 2, iface.empty() ? string("global") : iface);
        }

        pos = res.find("%C");
        if (pos != string::npos)
            res.replace(pos, 2, rule->getStr("ipt_chain"));
    }

    return "\"" + res + "\" ";
}


/*
 * Emits "on <interface> " when the rule names an interface; an empty
 * interface string produces no output.
 */
void PolicyCompiler_pf::PrintRule::_printInterface(PolicyRule *rule)
{
    string iface = rule->getInterfaceStr();
    if (iface.empty()) return;

    compiler->output << "on " << iface << " ";
}

/*
 * Emits "proto <name> " for the service's protocol.
 * No clause is printed for the "any" service, for custom services, or when
 * the protocol name is "ip".
 */
void PolicyCompiler_pf::PrintRule::_printProtocol(libfwbuilder::Service *srv)
{
    if (srv->isAny() || CustomService::isA(srv) || srv->getProtocolName()=="ip")
        return;

    compiler->output << "proto " << srv->getProtocolName() << " ";
}

/*
 * Builds the pf port expression for the inclusive port range [rs, re].
 *
 * Negative bounds are treated as 0. When both bounds are 0 the range means
 * "any port" and an empty string is returned. When only the start is set
 * (re == 0, rs > 0) the range is a single port rs. This applies to both
 * plain and negated ranges.
 *
 * Non-negated: "p", "<= re", ">= rs", or "rs-1 >< re+1"
 *   ('><' excludes its boundaries, so they are widened by one).
 * Negated:     "!= p", "> re", "< rs", or "rs <> re"
 *   ('<>' excludes exactly the inclusive range rs..re).
 */
string PolicyCompiler_pf::PrintRule::_printPort(int rs,int re,bool neg)
{
    ostringstream  str;

    if (rs<0) rs=0;
    if (re<0) re=0;

    if (rs==0 && re==0) return str.str();    // any port

    // Only the start was given: single port. Previously this was done only
    // for the non-negated case, so a negated single port came out as
    // "rs <> 0" instead of "!= rs".
    if (rs>re && re==0) re=rs;

    if (!neg)
    {
        if (rs==re)              str << rs;   // TODO: do we need '=' here ?
        else if (rs==0)          str << "<= " << re;
        else if (re==65535)      str << ">= " << rs;
        else
        {
            /*
             * Port range. Operator '><' defines a range in such a way that
             * the boundaries are not included. Since our ranges are
             * inclusive, move the boundaries outwards by one.
             */
            if (rs>0    ) rs--;
            if (re<65535) re++;
            str << rs << " >< " << re;
        }
    } else
    {
        if (rs==re)              str << "!= " << rs;
        else if (rs==0)          str << "> " << re;
        else if (re==65535)      str << "< " << rs;
        else                     str << rs << " <> " << re;
    }
    return str.str();
}

/*
 * Emits the source-port clause for the service rule element.
 *
 * The compiler makes sure that all services in rel represent the same
 * protocol. A single TCP/UDP service prints "port <expr> "; several
 * services print "port { <expr>, <expr> } ". Non TCP/UDP services
 * contribute nothing, because they have no source port.
 */
void PolicyCompiler_pf::PrintRule::_printSrcService(RuleElementSrv  *rel)
{
    /* I do not want to use rel->getFirst because it traverses the tree to
     * find the object. I'd rather use a cached copy in the compiler
     */
    FWObject *o=rel->front();
    if (o && FWReference::cast(o)!=NULL)
        o=compiler->getCachedObj( FWReference::cast(o)->getPointerId() );

    Service *srv= Service::cast(o);

    if (rel->size()==1)
    {
        if (UDPService::isA(srv) || TCPService::isA(srv))
        {
            string str=_printSrcService( srv , rel->getNeg());
            if (! str.empty() ) compiler->output << "port " << str << " ";
        }
        return;
    }

    string str;
    for (FWObject::iterator i=rel->begin(); i!=rel->end(); i++)
    {
        FWObject *obj= *i;
        if (FWReference::cast(obj)!=NULL) obj=compiler->getCachedObj(obj->getStr("ref"));
        Service *s=Service::cast( obj );
        assert(s);
        // Check the element being printed, not the first service of the
        // rule element (the original code tested 'srv' here).
        if (UDPService::isA(s) || TCPService::isA(s))
        {
            string str1= _printSrcService(s , rel->getNeg() );
            if (! str.empty() && ! str1.empty() )  str += ", ";
            str += str1;
        }
    }
    if ( !str.empty() )
        compiler->output << "port { " << str << " } ";
}

/*
 * Returns the pf source-port expression for a single service, or an empty
 * string if the service is neither TCP nor UDP.
 */
string PolicyCompiler_pf::PrintRule::_printSrcService(Service *srv,bool neg)
{
    if (!TCPService::isA(srv) && !UDPService::isA(srv))
        return string();

    int from = srv->getInt("src_range_start");
    int to   = srv->getInt("src_range_end");
    return _printPort(from, to, neg);
}

/*
 * Emits the destination part of the service clause for the rule element.
 *
 * TCP/UDP services print "port ...", ICMP services print "icmp-type ...",
 * and custom services print their platform code verbatim. For a single
 * service, TCP flags ("flags ...") and the IP "fragment" keyword are
 * emitted too. With several services the individual expressions are joined
 * with ", " inside braces; the type of the first service picks the keyword.
 */
void PolicyCompiler_pf::PrintRule::_printDstService(RuleElementSrv  *rel)
{
    FWObject *o=rel->front();
    if (o && FWReference::cast(o)!=NULL)
        o=compiler->getCachedObj( FWReference::cast(o)->getPointerId() );

    Service *srv= Service::cast(o);

    if (rel->size()==1)
    {
        string str=_printDstService( srv , rel->getNeg());
        if ( ! str.empty() )
        {
            if (UDPService::isA(srv) || TCPService::isA(srv))
                compiler->output << "port " << str << " ";
            else if (ICMPService::isA(srv))
                compiler->output << "icmp-type " << str << " ";
            else
                compiler->output << str << " ";
        }
        if (TCPService::isA(srv))
        {
            str=_printTCPFlags(TCPService::cast(srv));
            if (!str.empty()) compiler->output << "flags " << str << " ";
        }
        if (IPService::isA(srv) && (srv->getBool("fragm") || srv->getBool("short_fragm")) )
            compiler->output << " fragment ";
        return;
    }

    string str;
    for (FWObject::iterator i=rel->begin(); i!=rel->end(); i++)
    {
        FWObject *obj= *i;
        if (FWReference::cast(obj)!=NULL) obj=compiler->getCachedObj(obj->getStr("ref"));
        Service *s=Service::cast( obj );
        assert(s);
        string str1= _printDstService(s , rel->getNeg() );
        if (! str.empty() && ! str1.empty() )  str += ", ";
        str += str1;
    }
    if ( !str.empty() )
    {
        if (UDPService::isA(srv) || TCPService::isA(srv))
            compiler->output << "port { " << str << " } ";
        else if (ICMPService::isA(srv))
            compiler->output << "icmp-type { " << str << " } ";
        else
            // No endl here: the rest of the rule (state options, label)
            // must stay on the same pf line. Previously an endl split it.
            compiler->output << str << " ";
    }
}

/*
 * Returns the destination expression for a single service:
 *  - TCP/UDP: the destination port expression (see _printPort);
 *  - ICMP: "<type> " and optionally "code <code> " (nothing when the type
 *    is -1, presumably meaning "any type");
 *  - custom service: its code for this platform followed by a space.
 */
string PolicyCompiler_pf::PrintRule::_printDstService(Service *srv,bool neg)
{
    ostringstream out;

    const bool has_ports = TCPService::isA(srv) || UDPService::isA(srv);
    if (has_ports)
    {
        int from = srv->getInt("dst_range_start");
        int to   = srv->getInt("dst_range_end");
        out << _printPort(from, to, neg);
    }

    const bool icmp_typed = ICMPService::isA(srv) && srv->getInt("type") != -1;
    if (icmp_typed)
    {
        out << srv->getStr("type") << " ";
        if (srv->getInt("code") != -1)
            out << "code " << srv->getStr("code") << " ";
    }

    if (CustomService::isA(srv))
        out << CustomService::cast(srv)->getCodeForPlatform( compiler->myPlatformName() ) << " ";

    return out.str();
}

/*
 * Returns the pf "flags" argument "<set>/<mask>" for a TCP service, using
 * the letters U A P R S F. Returns an empty string when the service does
 * not inspect TCP flags.
 */
string PolicyCompiler_pf::PrintRule::_printTCPFlags(libfwbuilder::TCPService *srv)
{
    if (!srv->inspectFlags()) return string();

    string set_bits;
    if (srv->getTCPFlag(TCPService::URG)) set_bits += 'U';
    if (srv->getTCPFlag(TCPService::ACK)) set_bits += 'A';
    if (srv->getTCPFlag(TCPService::PSH)) set_bits += 'P';
    if (srv->getTCPFlag(TCPService::RST)) set_bits += 'R';
    if (srv->getTCPFlag(TCPService::SYN)) set_bits += 'S';
    if (srv->getTCPFlag(TCPService::FIN)) set_bits += 'F';

    string mask_bits;
    if (srv->getTCPFlagMask(TCPService::URG)) mask_bits += 'U';
    if (srv->getTCPFlagMask(TCPService::ACK)) mask_bits += 'A';
    if (srv->getTCPFlagMask(TCPService::PSH)) mask_bits += 'P';
    if (srv->getTCPFlagMask(TCPService::RST)) mask_bits += 'R';
    if (srv->getTCPFlagMask(TCPService::SYN)) mask_bits += 'S';
    if (srv->getTCPFlagMask(TCPService::FIN)) mask_bits += 'F';

    return set_bits + "/" + mask_bits;
}

/*
 * Emits a single address followed by a space.
 *
 *  - dynamic interface: "(<ifname>) ";
 *  - 0.0.0.0/0.0.0.0:   "any ";
 *  - otherwise:         "<addr>" plus "/<len>" unless the mask is /32.
 * Non-dynamic interfaces and IPv4 objects are always printed as /32 hosts.
 * 'neg' is not used here; callers print "! " via _printNegation.
 */
void PolicyCompiler_pf::PrintRule::_printAddr(Address  *o,bool neg)
{
    IPAddress addr = o->getAddress();
    Netmask   mask = o->getNetmask();

    Interface *iface = Interface::cast(o);
    if (iface != NULL && iface->isDyn())
    {
        compiler->output << "(" << iface->getName() << ") ";
        return;
    }

    if (iface != NULL || IPv4::cast(o) != NULL)
        mask = Netmask("255.255.255.255");

    const bool is_any = addr.toString()=="0.0.0.0" && mask.toString()=="0.0.0.0";
    if (is_any)
    {
        compiler->output << "any ";
        return;
    }

    compiler->output << addr.toString();
    if (mask.toString() != "255.255.255.255")
        compiler->output << "/" << mask.getLength();
    compiler->output << " ";
}

/*
 * Emits the children of 'grp' as a pf list "{ a , b } ". References are
 * resolved through the compiler's object cache, and every child must be
 * an Address.
 */
void PolicyCompiler_pf::PrintRule::_printAddrList(FWObject  *grp,bool negflag)
{
    bool first = true;

    compiler->output << "{ ";
    for (FWObject::iterator it = grp->begin(); it != grp->end(); ++it)
    {
        if (!first) compiler->output << ", ";
        first = false;

        FWObject *obj = *it;
        if (FWReference::cast(obj) != NULL)
            obj = compiler->getCachedObj(obj->getStr("ref"));

        Address *member = Address::cast(obj);
        assert(member);
        _printAddr(member, negflag);
    }
    compiler->output << "} ";
}

/*
 * Emits the source address part of the rule: optional "! ", then one of
 * a pf table reference "<name> " (when the first object has "pf_table"
 * set), a single address, or a "{ ... }" list.
 */
void PolicyCompiler_pf::PrintRule::_printSrcAddr(RuleElementSrc  *rel)
{
    FWObject *first = rel->front();
    if (first && FWReference::cast(first) != NULL)
        first = compiler->getCachedObj( FWReference::cast(first)->getPointerId() );

    Address *src = Address::cast(first);

    _printNegation(rel);

    if (first->getBool("pf_table"))
        compiler->output << "<" << first->getName() << "> ";
    else if (rel->size() == 1)
        _printAddr(src, rel->getNeg());
    else
        _printAddrList(rel, rel->getNeg());
}

/*
 * Emits the destination address part of the rule: optional "! ", then one
 * of a pf table reference "<name> " (when the first object has "pf_table"
 * set), a single address, or a "{ ... }" list.
 */
void PolicyCompiler_pf::PrintRule::_printDstAddr(RuleElementDst  *rel)
{
    FWObject *first = rel->front();
    if (first && FWReference::cast(first) != NULL)
        first = compiler->getCachedObj( FWReference::cast(first)->getPointerId() );

    Address *dst = Address::cast(first);

    _printNegation(rel);

    if (first->getBool("pf_table"))
        compiler->output << "<" << first->getName() << "> ";
    else if (rel->size() == 1)
        _printAddr(dst, rel->getNeg());
    else
        _printAddrList(rel, rel->getNeg());
}

/*
 * Emits "! " when the rule element is negated.
 */
void PolicyCompiler_pf::PrintRule::_printNegation(libfwbuilder::RuleElement  *rel)
{
    if (!rel->getNeg()) return;

    compiler->output << "! ";
}


/*
 * Constructs the printing processor with the given name; the 'init' flag
 * starts out set.
 */
PolicyCompiler_pf::PrintRule::PrintRule(const std::string &name) :
    PolicyRuleProcessor(name)
{
    this->init = true;
}

/*
 * Prints the next rule as one pf.conf line.
 *
 * Each rule is also appended to tmp_queue. When the rule's label differs
 * from the label of the previously printed rule, a comment header with the
 * label and the rule comment (one "# " line per comment line) is printed
 * first. Returns false when there are no more rules, true otherwise.
 */
bool PolicyCompiler_pf::PrintRule::processNext()
{
    PolicyRule *rule=getNext(); if (rule==NULL) return false;
    FWOptions  *ruleopt =rule->getOptionsObject();

    tmp_queue.push_back(rule);

    string rl=rule->getLabel();
    if (rl!=current_rule_label)
    {
        compiler->output << "# " << endl;
        compiler->output << "# Rule  " << rl << endl;

        string    comm=rule->getComment();
        string::size_type c1,c2;
        c1=0;
        while ( (c2=comm.find('\n',c1))!=string::npos ) {
            compiler->output << "# " << comm.substr(c1,c2-c1) << endl;
            c1=c2+1;
        }
        compiler->output << "# " << comm.substr(c1) << endl;
        compiler->output << "# " << endl;

        current_rule_label=rl;
    }

    RuleElementSrc *srcrel=rule->getSrc();
    RuleElementDst *dstrel=rule->getDst();
    RuleElementSrv *srvrel=rule->getSrv();
    Service        *srv   =compiler->getFirstSrv(rule);  assert(srv);

    _printAction(rule);
    _printDirection(rule);
    _printLogging(rule);
    if ( rule->getBool("quick") ) compiler->output << " quick ";
    _printInterface(rule);

    compiler->output << "inet ";

    _printProtocol(srv);

    compiler->output << " from ";
    _printSrcAddr(srcrel);
    _printSrcService(srvrel);

    compiler->output << " to ";
    _printDstAddr(dstrel);
    _printDstService(srvrel);

/*
 * Dealing with "keep state" and "modulate state" flags
 *
 * 1. both flags do not apply to deny/reject rules.
 * 2. modulate state applies only to TCP services. Since we use splitServices,
 *    all services in a rule are of the same protocol, therefore we can simply
 *    check type of srv
 */
    if ( ! ruleopt->getBool("stateless") && rule->getAction()==PolicyRule::Accept)
    {
        TCPService *tcpsrv=TCPService::cast(srv);

        if ( ! compiler->getCachedFwOpt()->getBool("accept_new_tcp_with_no_syn") &&
             tcpsrv!=NULL && !tcpsrv->inspectFlags() )
            compiler->output << "flags S/SA ";

        if (compiler->getCachedFwOpt()->getBool("modulate_state") && tcpsrv!=NULL)
            compiler->output << "modulate state ";
        else
            compiler->output << "keep state ";

        /*
         * pf accepts state options only inside parentheses:
         *     keep state ( max N, max-src-nodes N, max-src-states N )
         * Parentheses used to be printed only when source tracking was on,
         * so "max N" on its own produced the invalid "keep state max N".
         */
        int  max_state       = ruleopt->getInt("pf_rule_max_state");
        bool source_tracking = ruleopt->getBool("pf_source_tracking");

        if (max_state>0 || source_tracking)
        {
            compiler->output << " ( ";

            if (max_state>0)
            {
                compiler->output << " max " << max_state;
                if (source_tracking) compiler->output << ",";
            }

            if (source_tracking)
            {
                compiler->output << " max-src-nodes "
                                 << ruleopt->getInt("pf_max_src_nodes") << ",";
                compiler->output << " max-src-states "
                                 << ruleopt->getInt("pf_max_src_states");
            }

            compiler->output << " ) ";
        }
    }

    if (rule->getBool("allow_opts")) compiler->output << "allow-opts  ";

    _printLabel(rule);

    compiler->output << endl;

    return true;
}

bool PolicyCompiler_pf::PrintTables::processNext()
{
    PolicyCompiler_pf *pf_comp=dynamic_cast<PolicyCompiler_pf*>(compiler);

    slurp();
    if (tmp_queue.size()==0) return false;

/* print tables */

    compiler->output << endl;
    compiler->output << endl;
    compiler->output << "# Tables: (" << pf_comp->tables.size() << ")" << endl;
    for (FWObject::iterator i=pf_comp->tables.begin(); i!=pf_comp->tables.end(); i++)
    {
        FWObject *grp=*i;
        compiler->output << "table <" << grp->getName() << "> ";

        compiler->output << "{ ";
        for (FWObject::iterator i=grp->begin(); i!=grp->end(); i++)
        {
            if (i!=grp->begin())  compiler->output << ", ";
            FWObject *o= *i;
            if (FWReference::cast(o)!=NULL) o=compiler->getCachedObj(o->getStr("ref"));
            if (Interface::cast(o))
            {
                compiler->output << o->getName();
            } else
            {
                Address *A=Address::cast( o );
                assert(A);

                IPAddress addr=A->getAddress();
                Netmask   mask=A->getNetmask();

                if (IPv4::cast(A)!=NULL) {
                    mask=Netmask("255.255.255.255");
                }

                compiler->output << addr.toString();
                if (mask.toString()!="255.255.255.255") {
                    compiler->output << "/" << mask.getLength();
                }
            }
            compiler->output << " ";
        }
        compiler->output << "} ";
        compiler->output << endl;
    }
    compiler->output << endl;

    return true;
}
