/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2004  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include <climits>
#include <cstring>
#include <cstdlib>
#include <limits>
#include <list>
#include <stack>
#include <map>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <sstream>
#include <cassert>
#include "lexgen.h"

// Formats a value as text by streaming it through operator<<.
// Works for any type T that can be written to a std::ostream.
template<typename T>
std::string num2str(T val)
{
	std::ostringstream buffer;
	buffer << val;
	std::string text = buffer.str();
	return text;
}

// Creates a generator that has no rules yet.
// core_class and subclass are stored as given; writeLexer() later
// substitutes them for @CORE_CLASS@ and @SUBCLASS@ in the skeletons.
// Both tracking flags start out disabled and rule ids are handed out
// starting from zero.
LexGen::LexGen(const char* core_class, const char* subclass)
	: m_core_class(core_class)
	, m_subclass(subclass)
	, m_trackStreamPosition(false)
	, m_trackLineCol(false)
	, m_rule_id_generator(0)
{
	// Nothing else to set up.
}

// Builds a rule that has no trailing context: the trailer pointer is null.
LexGen::Rule::Rule(int id, int scond, const std::string& action, const Nfa& nfa)
	: m_id(id)
	, m_scond(scond)
	, m_action(action)
	, m_nfa(nfa)
	, m_pTrailer(0)
{
	// All state is set by the initializer list.
}

// Builds a rule with a trailing context.
// The trailer is copied into a heap-allocated NfaConcatenation that this
// Rule owns; it is released in the destructor.
LexGen::Rule::Rule(int id, int scond, const std::string& action, const Nfa& nfa, const Nfa& trailer)
	: m_id(id)
	, m_scond(scond)
	, m_action(action)
	, m_nfa(nfa)
	, m_pTrailer(new NfaConcatenation(trailer))
{
	// All state is set by the initializer list.
}

// Copy constructor: performs a deep copy, so the new Rule gets its own
// clone of the trailer (or a null pointer if the source has no trailer).
LexGen::Rule::Rule(const Rule& other)
	: m_id(other.m_id)
	, m_scond(other.m_scond)
	, m_action(other.m_action)
	, m_nfa(other.m_nfa)
	, m_pTrailer(
		other.m_pTrailer
			? new NfaConcatenation(*other.m_pTrailer)
			: 0
	)
	, m_options(other.m_options)
{
	// All state is set by the initializer list.
}

// Releases the owned trailer, if any.
LexGen::Rule::~Rule()
{
	// Deleting a null pointer is a no-op, so no explicit check is needed.
	delete m_pTrailer;
}

// Copy assignment: makes this rule a deep copy of `other`.
// Self-assignment is a no-op. Returns *this.
//
// The replacement trailer is cloned BEFORE the current one is deleted.
// If the allocation or the NfaConcatenation copy throws, m_pTrailer still
// points at the old, valid trailer instead of at freed memory (which the
// destructor would otherwise delete a second time).
LexGen::Rule&
LexGen::Rule::operator=(const Rule& other)
{
	if (this == &other)
		return *this;
	
	m_id = other.m_id;
	m_scond = other.m_scond;
	m_action = other.m_action;
	m_options = other.m_options;
	m_nfa = other.m_nfa;
	
	// May throw; nothing owned by this object has been released yet.
	NfaConcatenation* new_trailer =
		other.m_pTrailer ? new NfaConcatenation(*other.m_pTrailer) : 0;
	
	// From here on nothing can throw: swap in the clone.
	delete m_pTrailer;
	m_pTrailer = new_trailer;
	return *this;
}

// Registers a rule without trailing context under the next free rule id.
// Returns a reference to the options of the copy stored in m_rules_by_id,
// so the caller can adjust them after the rule has been added.
LexGen::Options&
LexGen::addRule(int scond, const Nfa& nfa, const std::string& action)
{
	Rule new_rule(m_rule_id_generator++, scond, action, nfa);
	const rules_by_id_type::value_type entry(new_rule.m_id, new_rule);
	rules_by_id_type::iterator stored = m_rules_by_id.insert(entry).first;
	return stored->second.m_options;
}

// Registers a rule with a trailing context under the next free rule id.
// Returns a reference to the options of the copy stored in m_rules_by_id,
// so the caller can adjust them after the rule has been added.
LexGen::Options&
LexGen::addRule(int scond, const Nfa& nfa, const Nfa& trailer, const std::string& action)
{
	Rule new_rule(m_rule_id_generator++, scond, action, nfa, trailer);
	const rules_by_id_type::value_type entry(new_rule.m_id, new_rule);
	rules_by_id_type::iterator stored = m_rules_by_id.insert(entry).first;
	return stored->second.m_options;
}

namespace {
	// Unary predicate that is true for every argument not equal to the
	// value captured at construction. Intended for std::find_if, e.g. to
	// test whether all elements of a range equal the first one.
	template<typename T> struct DiffersFrom
	{
		// explicit: a bare T should not silently convert into a predicate.
		explicit DiffersFrom(const T& val) : value(val) {}
		
		// const: the comparison does not modify the predicate, so it must
		// also be callable on a const functor.
		bool operator()(const T& val) const {
			return val != value;
		}
		
		T value; // the reference value arguments are compared against
	};
}

/*
Generates the lexer for all rules added so far and writes it out.

Steps, in order:
 1. Rules are grouped by start condition. The start conditions in use must
    be exactly 0..max without gaps.
 2. For every start condition the NFAs of its rules are combined and turned
    into a DFA by subset construction over the 256 possible byte values.
 3. All DFA states get a global number: states with outgoing transitions
    are numbered from 0 upwards, dead-end states from the top downwards.
 4. The state property / state info tables and the two-level symbol
    transition tables are built.
 5. The header and implementation skeletons are written with their
    placeholders substituted, followed (in impl) by m_included_code, one
    expanded action skeleton per rule, and the generated tables.

header:          receives the expanded header skeleton.
impl:            receives the expanded implementation skeleton, the actions
                 and all tables.
def_class, def_header, subclass_header:
                 substituted for @DEF_CLASS@, @DEF_HEADER@ and
                 @SUBCLASS_HEADER@ respectively.

If there are no rules, or the start conditions do not form 0..max, a
message is printed to std::cerr and exit(EXIT_FAILURE) is called.
checkForDangerousTrailingContexts() may terminate the process as well.
*/
void
LexGen::writeLexer(std::ostream& header, std::ostream& impl,
	const char* def_class, const char* def_header, const char* subclass_header)
{
	// Output skeletons. Each declaration below is completed by the included
	// .inc file. writeSkel() walks such an array until it hits a null entry.
	static const char* header_skel[] =
#	include "lexer_h.inc"

	static const char* impl_skel[] =
#	include "lexer_cpp.inc"
	
	static const char* action_skel[] =
#	include "action.inc"
	
	if (m_rules_by_id.empty()) {
		std::cerr << "lexgen: nothing to do" << std::endl;
		exit(EXIT_FAILURE);
	}
	// Group the rules by start condition.
	// NOTE(review): m_rules_by_scond is appended to without being cleared
	// in this function, so a second call would add every rule again.
	for (rules_by_id_type::iterator iter = m_rules_by_id.begin(); iter != m_rules_by_id.end(); ++iter) {
		m_rules_by_scond[iter->second.m_scond].push_back(&iter->second);
	}
	// Find the smallest and the largest start condition in use.
	int min_scond = INT_MAX;
	int max_scond = INT_MIN;
	rules_by_scond_type::const_iterator iter = m_rules_by_scond.begin();
	for (; iter != m_rules_by_scond.end(); ++iter) {
		const int scond = iter->first;
		if (scond <= min_scond)
			min_scond = scond;
		if (scond >= max_scond)
			max_scond = scond;
	}
	if (min_scond != 0) {
		std::cerr << "lexgen: Start conditions must start with zero" << std::endl;
		exit(EXIT_FAILURE);
	}
	// Every value in 0..max_scond must have at least one rule.
	for (int i = min_scond; i <= max_scond; i++) {
		if ((iter = m_rules_by_scond.find(i)) == m_rules_by_scond.end()) {
			std::cerr << "lexgen: Start condtions must not have gaps" << std::endl;
			exit(EXIT_FAILURE);
		}
	}
	
	int n_rules = m_rules_by_id.size();
	
	// One list of DFA states per start condition. Pointers to list elements
	// are stored in m_moves (and later in dstates) while the lists are still
	// being appended to.
	std::vector<std::list<DState> > dstates_by_scond(max_scond+1, std::list<DState>());
	
	int total_dstates = 0;
	iter = m_rules_by_scond.begin();
	for (; iter != m_rules_by_scond.end(); ++iter) {
		const int scond = iter->first;
		// Combined NFA of all rules of this start condition.
		LexerRuleComposition composition(iter->second);
		std::list<DState>& dstate_list = dstates_by_scond[scond];
		// The DFA start state is the epsilon closure of the NFA start state.
		DState first_dstate(scond);
		first_dstate.m_stateset = eClosure(&composition.m_states[0], composition.getStartPos());
		dstate_list.push_back(first_dstate);
		++total_dstates;
		// Worklist loop: the list grows while it is being iterated, so every
		// newly discovered DFA state is processed as well.
		for (std::list<DState>::iterator cur_dstate = dstate_list.begin();
		      cur_dstate != dstate_list.end(); ++cur_dstate) {
			// Determine which rule (if any) this DFA state finishes.
			// INT_MAX means "none found so far".
			int finish_rule = INT_MAX;
			bool is_quickfinish = false;
			bool is_lazy_finish = false;
			std::set<int>::const_iterator iter2 = cur_dstate->m_stateset.begin();
			for (; iter2 != cur_dstate->m_stateset.end(); ++iter2) {
				// NOTE(review): negative entries are skipped here; it is not
				// visible in this function when a state index can be negative.
				if (*iter2 < 0)
					continue;
				
				const NStateInfo* ns_info = &composition.m_stateInfos[*iter2];
				// The result of find() is dereferenced unchecked: the rule id
				// recorded for the NFA state must exist in m_rules_by_id.
				const Rule* rule = &m_rules_by_id.find(ns_info->m_rule)->second;
				
				if (ns_info->m_isSemiFinish) {
					cur_dstate->m_semiFinishFor.insert(ns_info->m_rule);
				}
				
				if (ns_info->m_isQuickFinish || ns_info->m_isLongFinish) {
					// Selection policy: a lazy rule always beats a non-lazy
					// one; among rules of equal laziness the lower id wins.
					bool set_rule = false;
					if (ns_info->m_rule < finish_rule) {
						set_rule = (rule->m_options.isLazy() || !is_lazy_finish);
					} else {
						set_rule = (rule->m_options.isLazy() && !is_lazy_finish);
					}
					if (set_rule) {
						finish_rule = ns_info->m_rule;
						is_quickfinish = ns_info->m_isQuickFinish;
						is_lazy_finish = rule->m_options.isLazy();
					}
				}
			}
			if (finish_rule < INT_MAX) {
				cur_dstate->m_finishFor = finish_rule;
				cur_dstate->m_isQuickFinish = is_quickfinish;
			}
			
			// Compute the transition for every byte value. The state stays
			// marked as a dead end only if none of them leads anywhere.
			cur_dstate->m_isDeadEnd = true;
			for (unsigned int ch = 0; ch < 256; ch++) {
				std::set<int> target_states = eClosure(&composition.m_states[0],
					move(&composition.m_states[0], cur_dstate->m_stateset, ch));
				if (target_states.empty()) {
					cur_dstate->m_moves[ch] = 0;
					continue;
				}
				cur_dstate->m_isDeadEnd = false;
				// Reuse the DFA state with the same NFA state set if it
				// already exists, otherwise append a new one.
				std::list<DState>::iterator fit = std::find(dstate_list.begin(), dstate_list.end(), target_states);
				if (fit != dstate_list.end()) {
					cur_dstate->m_moves[ch] = &*fit;
				} else {
					dstate_list.push_back(DState(scond));
					dstate_list.back().m_stateset = target_states;
					++total_dstates;
					cur_dstate->m_moves[ch] = &dstate_list.back();
				}
				checkForDangerousTrailingContexts(*cur_dstate, *cur_dstate->m_moves[ch], composition);
			}
		}
	}
	
	// Assign global numbers to all DFA states. States with outgoing
	// transitions fill positions from the front, dead ends from the back,
	// so that states 0..num_non_deadend_states-1 are exactly the ones that
	// need rows in the transition index.
	std::vector<int> start_state_by_scond(max_scond+1, 0);
	std::vector<const DState*> dstates(total_dstates, static_cast<const DState*>(0));
	std::map<const DState*, int> dstate_num_by_ptr;
	int normal_insert_pos = 0;
	int deadend_insert_pos = total_dstates;
	for (int i = 0; i <= max_scond; ++i) {
		std::list<DState>::iterator iter1 = dstates_by_scond[i].begin();
		for (; iter1 != dstates_by_scond[i].end(); ++iter1) {
			int pos;
			if (iter1->m_isDeadEnd) {
				pos = --deadend_insert_pos;
			} else {
				pos = normal_insert_pos++;
			}
			dstates[pos] = &*iter1;
			dstate_num_by_ptr[dstates[pos]] = pos;
			// The first element of each list is that start condition's
			// start state.
			if (iter1 == dstates_by_scond[i].begin()) {
				start_state_by_scond[i] = pos;
			}
		}
	}
	int num_non_deadend_states = normal_insert_pos;
	
	// stateprop_table: one packed word per DFA state (layout below).
	// stateinfo_table: rule ids referenced through the offset stored in
	// the packed word.
	std::vector<unsigned int> stateprop_table;
	stateprop_table.reserve(total_dstates);
	std::vector<int> stateinfo_table;
	
	for (int i = 0; i < total_dstates; i++) {
		/*
		stateprop layout:
		bits 0..2:
		  000: greedy long finish
		  001: lazy long finish
		  010: lazy-on-block-end long finish
		  100: greedy quick finish
		  101: lazy quick finish
		  110: lazy-on-block-end quick finish
		  111: no complete finishes
		bit 3: semifinish(es)
		the rest of bits: offset in stateinfo table
		*/
		const DState* dstate = dstates[i];
		unsigned int stateprop = 0;
		// Offset of this state's first stateinfo entry (if it gets any).
		unsigned int off = stateinfo_table.size();
		if (dstate->m_finishFor == -1) {
			stateprop |= 7;
		} else {
			const Rule* rule = &m_rules_by_id.find(dstate->m_finishFor)->second;
			// The finished rule id comes first in this state's entries.
			stateinfo_table.push_back(dstate->m_finishFor);
			stateprop |= off << 4;
			if (dstate->m_isQuickFinish) {
				stateprop |= 4;
			}
			if (rule->m_options.isLazy()) {
				stateprop |= 1;
			} else if (rule->m_options.isLazyOnBlockEnd()) {
				stateprop |= 2;
			}
		}
		if (!dstate->m_semiFinishFor.empty()) {
			// Semifinish rule ids follow, terminated by -1.
			std::set<int>::const_iterator iter = dstate->m_semiFinishFor.begin();
			for (; iter != dstate->m_semiFinishFor.end(); ++iter) {
				stateinfo_table.push_back(*iter);
			}
			stateinfo_table.push_back(-1);
			stateprop |= 8;
			// Same offset as above; OR-ing it in a second time is harmless.
			stateprop |= off << 4;
		}
		stateprop_table.push_back(stateprop);
	}
	
	// Two-level transition tables. For every non-dead-end state the 256
	// transitions are split into 16 rows of 16 entries:
	//  - symbol_transition_index holds 16 entries per state, one per row;
	//  - symbol_transition_table holds the distinct non-uniform rows.
	std::vector<int> symbol_transition_index;
	symbol_transition_index.reserve(num_non_deadend_states*16);
	std::vector<int> symbol_transition_table;
	// NOTE(review): eof_transition_table is reserved but never filled or
	// written out in this function.
	std::vector<int> eof_transition_table;
	eof_transition_table.reserve(num_non_deadend_states);
	
	for (int i = 0; i < num_non_deadend_states; ++i) {
		DState* const* moves = dstates[i]->m_moves;
		for (int row = 0; row < 16; ++row) {
			// Target state numbers for bytes row*16 .. row*16+15;
			// -1 stands for "no transition".
			int transition_row[16];
			for (int off = 0; off < 16; ++off) {
				DState* target = moves[(row<<4)+off];
				if (target) {
					transition_row[off] = dstate_num_by_ptr[target];
				} else {
					transition_row[off] = -1;
				}
			}
			if (std::find_if(transition_row, transition_row+16,
			    DiffersFrom<int>(transition_row[0])) == transition_row+16) {
				// All 16 entries are equal: encode the common target t
				// directly in the index as the negative value -t-2
				// (t == -1 gives -1, t == 0 gives -2, and so on).
				symbol_transition_index.push_back(-transition_row[0]-2);
			} else {
				// Mixed row: store its row number (>= 0) in the index,
				// reusing an identical row if one is already in the table.
				// Note that this `i` shadows the outer loop variable.
				bool found = false;
				for (unsigned int i = 0; i < symbol_transition_table.size(); i += 16) {
					if (!std::memcmp(&symbol_transition_table[i], transition_row, sizeof(transition_row))) {
						found = true;
						symbol_transition_index.push_back(i>>4);
						break;
					}
				}
				if (!found) {
					symbol_transition_index.push_back(symbol_transition_table.size()>>4);
					symbol_transition_table.insert(symbol_transition_table.end(), transition_row, transition_row+16);
				}
			}
		}
	}
	
	// Pick the narrowest integer types for the generated tables.
	// transition_type must hold state numbers and -1; the index type must
	// in addition hold the encoded uniform rows (down to -total_dstates-1)
	// and the row numbers of symbol_transition_table.
	const char* transition_type = getTypeForRange(-total_dstates - 1, total_dstates - 1);
	const char* transition_index_type = getTypeForRange(
		-total_dstates - 1,
		std::max<int>(total_dstates - 1, (symbol_transition_table.size() >> 4) - 1)
	);
	
	// Placeholder -> replacement text used when expanding the skeletons.
	std::map<std::string, std::string> substitutions;
	substitutions["@CORE_CLASS@"] = m_core_class;
	substitutions["@SUBCLASS@"] = m_subclass;
	substitutions["@SUBCLASS_HEADER@"] = subclass_header;
	substitutions["@DEF_CLASS@"] = def_class;
	substitutions["@DEF_HEADER@"] = def_header;
	substitutions["@NUM_RULES@"] = num2str(m_rules_by_id.size());
	substitutions["@NUM_NON_DEADEND_STATES@"] = num2str(num_non_deadend_states);
	substitutions["@TRANSITION_TYPE@"] = transition_type;
	substitutions["@TRANSITION_INDEX_TYPE@"] = transition_index_type;
	substitutions["@TRACK_STREAM_POSITION@"] = m_trackStreamPosition ? "1" : "0";
	substitutions["@TRACK_LINE_COL@"] = m_trackLineCol ? "1" : "0";
	substitutions["@TRACKING_ENABLED@"] = (m_trackStreamPosition || m_trackLineCol) ? "1" : "0";
	
	writeSkel(header, header_skel, substitutions);
	writeSkel(impl, impl_skel, substitutions);
	
	impl << std::endl << m_included_code << std::endl;
	
	// One expanded action skeleton per rule, parameterised by the rule id
	// and the rule's action code.
	for (rules_by_id_type::iterator iter = m_rules_by_id.begin(); iter != m_rules_by_id.end(); ++iter) {
		substitutions["@RULE_ID@"] = num2str(iter->first);
		substitutions["@ACTION_CODE@"] = iter->second.m_action;
		writeSkel(impl, action_skel, substitutions);
		impl << std::endl;
	}
	
	// Action table indexed by rule id. Entries are generated for
	// 0..n_rules-1, which matches the ids handed out by addRule()
	// (consecutive, starting at zero).
	impl << "const " << m_core_class << "::ActionPtr " << m_core_class << "::m_actionTable[] = {";
	for (int i = 0; i < n_rules; ++i) {
		if (i != 0)
			impl << ",";
		impl << std::endl << "\t&" << m_core_class << "::Actions::action<" << i << ">";
	}
	impl << std::endl << "};" << std::endl;
	
	// The remaining tables are emitted as comma-separated values with a
	// fixed number of values per output line (16 or 8) and no comma after
	// the very last value.
	impl << "const " << transition_index_type << " " << m_core_class << "::m_symbolTransitionIndex[] = {" << std::endl;
	for (unsigned int i = 0; i < symbol_transition_index.size();) {
		for (unsigned int j = 0; i < symbol_transition_index.size() && j < 16; ++i, ++j) {
			impl << symbol_transition_index[i];
			if (i != symbol_transition_index.size()-1) {
				impl << ", ";
			}
		}
		impl << std::endl;
	}
	impl << "};" << std::endl;
	
	impl << "const " << transition_type << " " << m_core_class << "::m_symbolTransitionTable[] = {" << std::endl;
	for (unsigned int i = 0; i < symbol_transition_table.size();) {
		for (unsigned int j = 0; i < symbol_transition_table.size() && j < 16; ++i, ++j) {
			impl << symbol_transition_table[i];
			if (i != symbol_transition_table.size()-1) {
				impl << ", ";
			}
		}
		impl << std::endl;
	}
	impl << "};" << std::endl;
	
	impl << "const " << transition_type << " " << m_core_class << "::m_stateStartTable[] = {" << std::endl;
	for (unsigned int i = 0; i < start_state_by_scond.size();) {
		for (unsigned int j = 0; i < start_state_by_scond.size() && j < 8; ++i, ++j) {
			impl << start_state_by_scond[i];
			if (i != start_state_by_scond.size()-1) {
				impl << ", ";
			}
		}
		impl << std::endl;
	}
	impl << "};" << std::endl;
	
	impl << "const uint32_t " << m_core_class << "::m_statePropTable[] = {" << std::endl;
	for (unsigned int i = 0; i < stateprop_table.size();) {
		for (unsigned int j = 0; i < stateprop_table.size() && j < 8; ++i, ++j) {
			impl << stateprop_table[i];
			if (i != stateprop_table.size()-1) {
				impl << ", ";
			}
		}
		impl << std::endl;
	}
	impl << "};" << std::endl;
	
	// NOTE(review): this table holds rule ids (and -1 terminators) but is
	// declared with transition_type, which was sized from total_dstates.
	// Confirm that rule ids always fit into that type.
	impl << "const " << transition_type << " " << m_core_class << "::m_stateInfoTable[] = {" << std::endl;
	for (unsigned int i = 0; i < stateinfo_table.size();) {
		for (unsigned int j = 0; i < stateinfo_table.size() && j < 16; ++i, ++j) {
			impl << stateinfo_table[i];
			if (i != stateinfo_table.size()-1) {
				impl << ", ";
			}
		}
		impl << std::endl;
	}
	impl << "};" << std::endl;
}

std::set<int>
LexGen::eClosure(const NfaState* states, int from)
{
	std::set<int> frm;
	frm.insert(from);
	return eClosure(states, frm);
}

std::set<int>
LexGen::eClosure(const NfaState* states, const std::set<int>& from)
{
	std::set<int> res;
	std::stack<int> nfa_states;
	std::set<int>::const_iterator it = from.begin();
	for (; it != from.end(); ++it) {
		nfa_states.push(*it);
		res.insert(*it);
	}
	while (!nfa_states.empty()) {
		int nstate = nfa_states.top();
		nfa_states.pop();
		const std::set<int>& emoves = states[nstate].getEpsilonTransitions();
		std::set<int>::const_iterator iter = emoves.begin();
		for (; iter != emoves.end(); ++iter) {
			int new_state = nstate+*iter;
			if (res.find(new_state) == res.end()) {
				res.insert(new_state);
				nfa_states.push(new_state);
			}
		}
	}
	return res;
}

std::set<int>
LexGen::move(const NfaState* states, const std::set<int>& from, unsigned char ch)
{
	std::set<int> res;
	std::set<int>::const_iterator it = from.begin();
	for (; it != from.end(); ++it) {
		int move = states[*it].getSymbolTransition(ch);
		if (move != NfaState::NO_TRANSITION) {
			res.insert(*it+move);
		}
	}
	return res;
}

// Merges the NFAs of the given rules into one state array.
//
// State 0 is a fresh start state with an epsilon transition to the start
// of every rule's NFA. m_stateInfos runs parallel to m_states and records,
// for every state, the rule it belongs to plus its role flags:
//  - rule without trailer: its finish state is marked as a quick finish;
//  - rule with trailer: the trailer NFA is appended after the main NFA,
//    the main finish state is linked to the trailer start by an epsilon
//    transition and marked as a semifinish, and the trailer finish state
//    is marked as a long finish.
LexGen::LexerRuleComposition::LexerRuleComposition(const std::list<Rule*>& rules)
:	m_states(1, NfaState()),
	m_stateInfos(1, NStateInfo(-1, false))
{
	typedef std::list<Rule*>::const_iterator RuleIter;
	
	for (RuleIter rule_it = rules.begin(); rule_it != rules.end(); ++rule_it) {
		const Rule& rule = **rule_it;
		
		// Append the main NFA and hook its start up to state 0.
		const int main_base = m_states.size();
		m_states.insert(m_states.end(), rule.m_nfa.statesBegin(), rule.m_nfa.statesEnd());
		m_stateInfos.insert(
			m_stateInfos.end(),
			rule.m_nfa.statesEnd()-rule.m_nfa.statesBegin(),
			NStateInfo(rule.m_id, false)
		);
		m_states[0].addEpsilonTransition(main_base+rule.m_nfa.getStartPos());
		
		if (!rule.m_pTrailer) {
			// No trailing context: reaching the finish state completes
			// the rule.
			const int quick_finish = main_base+rule.m_nfa.getFinishPos();
			m_stateInfos[quick_finish].m_isQuickFinish = true;
			continue;
		}
		
		// Append the trailer NFA; its states are created with the second
		// NStateInfo argument set to true.
		const int trailer_base = m_states.size();
		m_states.insert(m_states.end(), rule.m_pTrailer->statesBegin(), rule.m_pTrailer->statesEnd());
		m_stateInfos.insert(
			m_stateInfos.end(),
			rule.m_pTrailer->statesEnd()-rule.m_pTrailer->statesBegin(),
			NStateInfo(rule.m_id, true)
		);
		
		// Link the end of the main NFA to the start of the trailer.
		// Epsilon transitions are relative to the owning state.
		const int semifinish = main_base+rule.m_nfa.getFinishPos();
		const int trailer_start = trailer_base+rule.m_pTrailer->getStartPos();
		m_states[semifinish].addEpsilonTransition(trailer_start-semifinish);
		
		// States reachable from the trailer start without consuming input
		// do not count as trailer states.
		const std::set<int> entry_states = eClosure(&m_states[0], trailer_start);
		for (std::set<int>::const_iterator entry = entry_states.begin();
		     entry != entry_states.end(); ++entry) {
			m_stateInfos[*entry].m_isTrailer = false;
		}
		
		m_stateInfos[semifinish].m_isSemiFinish = true;
		const int long_finish = trailer_base+rule.m_pTrailer->getFinishPos();
		m_stateInfos[long_finish].m_isLongFinish = true;
	}
}

// Creates a DFA state for the given start condition.
// m_finishFor == -1 means the state does not finish any rule; the state
// is initially neither a quick finish nor a dead end.
LexGen::DState::DState(int scond)
	: m_finishFor(-1)
	, m_isQuickFinish(false)
	, m_isDeadEnd(false)
	, m_scond(scond)
{
}

void
LexGen::checkForDangerousTrailingContexts(const DState& src_state,
	const DState& dst_state, const LexerRuleComposition& composition)
{
	std::set<int>::const_iterator iter1 = src_state.m_stateset.begin();
	for (; iter1 != src_state.m_stateset.end(); ++iter1) {
		std::set<int>::const_iterator iter2 = dst_state.m_stateset.begin();
		for (; iter2 != dst_state.m_stateset.end(); ++iter2) {
			const NStateInfo& info1 = composition.m_stateInfos[*iter1];
			const NStateInfo& info2 = composition.m_stateInfos[*iter2];
			if (info1.m_rule != info2.m_rule) {
				continue;
			}
			if (info1.m_isTrailer && info2.m_isSemiFinish) {
				std::cerr << "Error: dangerous trailing context for rule #" << (info1.m_rule+1)
					<< " (counting from 1)" << std::endl;
				exit(EXIT_FAILURE);
			}
		}
	}
}

// Returns the name of the narrowest signed fixed-width integer type
// ("int8_t", "int16_t" or "int32_t") able to hold every value in
// [min, max]. Ranges that fit neither int8_t nor int16_t yield "int32_t".
const char*
LexGen::getTypeForRange(int min, int max)
{
	typedef std::numeric_limits<int8_t> Int8;
	typedef std::numeric_limits<int16_t> Int16;
	
	if (min >= Int8::min() && max <= Int8::max()) {
		return "int8_t";
	}
	if (min >= Int16::min() && max <= Int16::max()) {
		return "int16_t";
	}
	return "int32_t";
}

// Writes a skeleton to `out`.
// `skel` is a null-terminated array of text pieces. A piece that exactly
// matches a key in `substitutions` is replaced by the mapped value; every
// other piece is written unchanged.
void
LexGen::writeSkel(std::ostream& out, const char** skel, const std::map<std::string, std::string>& substitutions)
{
	typedef std::map<std::string, std::string>::const_iterator SubstIter;
	
	for (const char** piece = skel; *piece; ++piece) {
		const SubstIter hit = substitutions.find(*piece);
		if (hit == substitutions.end()) {
			out << *piece;
			continue;
		}
		out << hit->second;
	}
}

