/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2006  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "DnsCache.h"
#include "RefCountable.h"
#include "AtomicCounter.h"
#include "InetAddr.h"
#include "SymbolicInetAddr.h"
#include "TimeStamp.h"
#include "TimeDelta.h"
#include <ace/config-lite.h>
#include <ace/Synch.h>
#include <ace/Singleton.h>
#include <ace/OS_NS_sys_time.h>
#include <boost/multi_index_container.hpp>
#include <boost/multi_index/member.hpp>
#include <boost/multi_index/ordered_index.hpp>
#include <boost/multi_index/sequenced_index.hpp>
#include <iterator>
#include <utility>

using namespace std;
using namespace boost::multi_index;

/**
 * One cached resolution: the name that was looked up, the addresses it
 * resolved to, and the absolute point in time at which the record expires.
 */
struct DnsCache::Entry
{
	Entry(SymbolicInetAddr const& name,
		std::vector<InetAddr> const& addresses,
		TimeStamp const& expiry)
	:	symbolicAddr(name),
		resolvedAddrs(addresses),
		timeout(expiry)
	{
	}
	
	// Lookup key of the record.
	SymbolicInetAddr symbolicAddr;
	
	// Addresses the key resolved to.
	std::vector<InetAddr> resolvedAddrs;
	
	// Absolute expiry time (not a duration).
	TimeStamp timeout;
};


/**
 * Concrete DnsCache guarded by a thread mutex.
 *
 * Entries are indexed two ways:
 *  - by symbolic address (ordered, unique) for lookups;
 *  - by a sequence that put() keeps in least-recently-stored-first order,
 *    whose front is used both for expiry and for eviction when full.
 */
class DnsCache::Impl : public DnsCache
{
public:
	Impl();
	
	virtual ~Impl();
	
	virtual void put(
		SymbolicInetAddr const& symbolic_addr,
		std::vector<InetAddr> const& resolved_addrs);
	
	virtual std::vector<InetAddr> get(
		SymbolicInetAddr const& symbolic_addr);
private:
	// Index tags: used only as type labels to select an index.
	struct SymbolicAddrTag {};
	struct SequenceTag {};
	
	typedef ACE_Thread_Mutex Mutex;
	
	// Key extractor and index specifications for the container below.
	typedef member<Entry, SymbolicInetAddr, &Entry::symbolicAddr> AddrKey;
	typedef ordered_unique<tag<SymbolicAddrTag>, AddrKey> AddrIndexSpec;
	typedef sequenced<tag<SequenceTag> > SequenceIndexSpec;
	
	typedef multi_index_container<
		Entry,
		indexed_by<AddrIndexSpec, SequenceIndexSpec>
	> Container;
	
	typedef Container::index<SymbolicAddrTag>::type SymbolicAddrIdx;
	typedef Container::index<SequenceTag>::type SequenceIdx;
	
	// Drops expired entries; caller must hold m_mutex.
	void removeTimedOut();
	
	Mutex m_mutex;
	Container m_container;
};


/**
 * Returns the shared DnsCache object, an Impl managed by ACE_Singleton
 * (with ACE_Recursive_Thread_Mutex as the singleton's lock type).
 */
DnsCache*
DnsCache::instance()
{
	typedef ACE_Singleton<Impl, ACE_Recursive_Thread_Mutex> Singleton;
	Impl* const shared = Singleton::instance();
	return shared;
}


/*============================ DnsCache::Impl ===========================*/

DnsCache::Impl::Impl()
:	m_mutex(),
	m_container()
{
	// The cache starts out empty.
}

DnsCache::Impl::~Impl()
{
	// Nothing to do explicitly: the mutex and container clean up themselves.
}

/**
 * Stores (or refreshes) the resolution of symbolic_addr.
 *
 * The record expires TIMEOUT seconds from now. If the name is already
 * cached, its record is overwritten; otherwise, if the cache holds
 * CAPACITY or more records, the least recently stored one is overwritten.
 * In every case the stored record ends up at the back of the sequence.
 *
 * @param symbolic_addr  the name that was resolved (its port is ignored)
 * @param resolved_addrs the addresses it resolved to; an empty vector is
 *                       not cached
 */
void
DnsCache::Impl::put(
	SymbolicInetAddr const& symbolic_addr,
	std::vector<InetAddr> const& resolved_addrs)
{
	// Nothing worth storing: get() already yields an empty vector on a miss.
	if (resolved_addrs.empty()) {
		return;
	}
	
	ACE_GUARD_RETURN(Mutex, guard, m_mutex, );
	
	removeTimedOut();
	
	TimeStamp expires_at = TimeStamp::fromTimeval(ACE_OS::gettimeofday());
	expires_at += TimeDelta::fromSec(TIMEOUT);
	
	// Keys are always stored with port 80, so requests for the same host on
	// different ports share one record; get() re-applies the caller's port.
	Entry fresh(symbolic_addr, resolved_addrs, expires_at);
	fresh.symbolicAddr.setPort(80);
	
	SymbolicAddrIdx& by_addr = m_container.get<SymbolicAddrTag>();
	SequenceIdx& by_age = m_container.get<SequenceTag>();
	
	SymbolicAddrIdx::iterator const existing = by_addr.find(fresh.symbolicAddr);
	
	if (existing == by_addr.end() && m_container.size() < CAPACITY) {
		// New name and room left: a plain insert lands at the back of
		// the sequence already.
		by_addr.insert(fresh);
		return;
	}
	
	SequenceIdx::iterator slot;
	if (existing != by_addr.end()) {
		// Refresh the record for the same name.
		by_addr.replace(existing, fresh);
		slot = m_container.project<SequenceTag>(existing);
	} else {
		// Full: recycle the least recently stored record.
		slot = by_age.begin();
		by_age.replace(slot, fresh);
	}
	
	// Mark the touched record as the most recently stored one.
	by_age.relocate(by_age.end(), slot);
}

/**
 * Looks up a cached resolution.
 *
 * Expired records are purged first. The host is matched regardless of the
 * port in symbolic_addr. Every returned address carries the port of
 * symbolic_addr.
 *
 * @param symbolic_addr the name to look up
 * @return the cached addresses, or an empty vector on a miss (or if the
 *         mutex could not be acquired)
 */
std::vector<InetAddr>
DnsCache::Impl::get(
	SymbolicInetAddr const& symbolic_addr)
{
	vector<InetAddr> addrs;
	
	ACE_GUARD_RETURN(Mutex, guard, m_mutex, addrs);
	
	removeTimedOut();
	
	// Keys are stored with port 80 (see put()).
	SymbolicInetAddr key(symbolic_addr);
	key.setPort(80);
	
	SymbolicAddrIdx& by_addr = m_container.get<SymbolicAddrTag>();
	SymbolicAddrIdx::iterator const hit = by_addr.find(key);
	if (hit == by_addr.end()) {
		return addrs;
	}
	
	addrs = hit->resolvedAddrs;
	for (vector<InetAddr>::iterator it = addrs.begin(); it != addrs.end(); ++it) {
		it->set_port_number(symbolic_addr.getPort());
	}
	
	return addrs;
}

/**
 * Drops every record whose timeout is not later than the current time.
 *
 * Must be called with m_mutex held (put() and get() do so).
 *
 * The purge relies on the sequence index being ordered by timeout. put()
 * always moves the record it stores to the back, with a timeout of
 * "now + TIMEOUT". So as long as the clock never runs backwards, timeouts
 * are non-decreasing from front to back, and the loop can stop at the
 * first unexpired record.
 */
void
DnsCache::Impl::removeTimedOut()
{
	if (m_container.empty()) {
		return;
	}
	
	TimeStamp now = TimeStamp::fromTimeval(ACE_OS::gettimeofday());
	SequenceIdx& idx = m_container.get<SequenceTag>();
	
	// The back record is the most recently stored one, so
	// back().timeout - TIMEOUT is the latest time anything was stored.
	// If the clock now reads earlier than that, it went backwards.
	// Checking the back rather than the front (oldest) record matters.
	// A backward jump smaller than the oldest record's age would otherwise
	// go unnoticed. Records stored afterwards would then get earlier
	// timeouts than records ahead of them, breaking the ordering the loop
	// below depends on, and letting expired records be served by get().
	// Since stored timeouts can't be trusted any more, start over.
	if (now < idx.back().timeout - TimeDelta::fromSec(TIMEOUT)) {
		m_container.clear();
		return;
	}
	
	// Timeouts are non-decreasing along the sequence: pop expired records
	// from the front until the first one that is still valid.
	while (!idx.empty() && !(idx.front().timeout > now)) {
		idx.pop_front();
	}
}
