/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2006  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "Socks4Requester.h"
#include "Reactor.h"
#include "InetAddr.h"
#include "SymbolicInetAddr.h"
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <stddef.h>

using namespace std;

// --- Request header fields ---
// Protocol version byte sent at the start of every request.
static const unsigned char SOCKS_VERSION_4 = 0x04;
// Command code for a CONNECT request.
static const unsigned char SOCKS_CMD_CONNECT = 0x01;

// --- Reply header fields ---
// Version byte this code expects at the start of a reply.
static const unsigned char SOCKS_REPLY_VERSION_0 = 0x00;
// Reply status codes (decimal values, as used by SOCKS4).
static const unsigned char SOCKS_STATUS_SUCCESS = 90;
static const unsigned char SOCKS_STATUS_REJECTED_OR_FAILED = 91;
static const unsigned char SOCKS_STATUS_IDENT_FAILED = 92;
static const unsigned char SOCKS_STATUS_WRONG_IDENT = 93;


/**
 * Creates an idle requester. Nothing is sent until one of the
 * requestConnection() overloads is called.
 */
Socks4Requester::Socks4Requester()
	: m_state(ST_INACTIVE)
{
	// All other members use their default construction.
}

/**
 * Destroys the requester. There is no explicit cleanup in the body.
 */
Socks4Requester::~Socks4Requester()
{
	// intentionally empty
}

/**
 * Starts a SOCKS4 CONNECT request to a numeric IPv4 destination over an
 * already-connected handle.
 *
 * Any request still in progress is aborted first. Non-IPv4 addresses are
 * reported to the listener immediately with UNSUPPORTED_ADDRESS_TYPE.
 * Otherwise the request message is built and writing begins. The outcome
 * is reported later via the listener callbacks.
 *
 * @param listener receives onRequestSuccess() / onRequestFailure()
 * @param reactor  reactor used to drive the reader/writer
 * @param handle   socket already connected to the SOCKS server
 * @param addr     destination address (must be AF_INET)
 * @param username SOCKS4 USERID field
 */
void
Socks4Requester::requestConnection(
	Listener& listener, Reactor& reactor, ACE_HANDLE handle,
	InetAddr const& addr, std::string const& username)
{
	abort(); // drop whatever request may still be active

	if (addr.get_type() == AF_INET) {
		std::vector<unsigned char> request(createConnectMsg(addr, username));
		m_msgConnect.swap(request);
		m_observerLink.setObserver(&listener);
		m_state = ST_SENDING_REQUEST; // must be before startWriting()
		m_readerWriter.activate(*this, reactor, handle);
		m_readerWriter.startWriting(&m_msgConnect[0], m_msgConnect.size());
	} else {
		listener.onRequestFailure(
			SocksError(SocksError::UNSUPPORTED_ADDRESS_TYPE)
		);
	}
}

/**
 * Starts a SOCKS4a CONNECT request, letting the SOCKS server resolve the
 * destination host name.
 *
 * Any request still in progress is aborted first. The outcome is reported
 * later via the listener callbacks.
 *
 * @param listener receives onRequestSuccess() / onRequestFailure()
 * @param reactor  reactor used to drive the reader/writer
 * @param handle   socket already connected to the SOCKS server
 * @param addr     destination host name and port
 * @param username SOCKS4 USERID field
 */
void
Socks4Requester::requestConnection(
	Socks4aTag, Listener& listener, Reactor& reactor, ACE_HANDLE handle,
	SymbolicInetAddr const& addr, std::string const& username)
{
	abort(); // drop whatever request may still be active

	std::vector<unsigned char> request(
		createConnectMsg(SOCKS4A, addr, username)
	);
	m_msgConnect.swap(request);
	m_observerLink.setObserver(&listener);

	m_state = ST_SENDING_REQUEST; // must be before startWriting()
	m_readerWriter.activate(*this, reactor, handle);
	m_readerWriter.startWriting(&m_msgConnect[0], m_msgConnect.size());
}

/**
 * Cancels any request in progress and returns to ST_INACTIVE.
 *
 * Deactivates the reader/writer, detaches the listener (no callback is
 * made), and releases the request buffer's storage.
 */
void
Socks4Requester::abort()
{
	m_readerWriter.deactivate();
	m_observerLink.setObserver(0);

	// Swapping with an empty vector frees the capacity, not just the contents.
	std::vector<unsigned char> emptyBuffer;
	m_msgConnect.swap(emptyBuffer);

	m_state = ST_INACTIVE;
}

void
Socks4Requester::onReadDone()
{
	assert(m_state == ST_RECEIVING_RESPONSE);
	onResponseReceived();
}

void
Socks4Requester::onReadError()
{
	handleRequestFailure(SocksError::CONNECTION_CLOSED);
}

/**
 * Reader/writer callback: the whole request has been sent.
 * Switches to ST_RECEIVING_RESPONSE and starts reading the reply.
 */
void
Socks4Requester::onWriteDone()
{
	assert(m_state == ST_SENDING_REQUEST);

	// The state must change before reading starts, in case the reader
	// calls back synchronously.
	m_state = ST_RECEIVING_RESPONSE;
	// A SOCKS4 reply is a fixed 8 bytes: VN, CD, DSTPORT(2), DSTIP(4).
	m_readerWriter.startReading(&m_response[0], 8);
}

void
Socks4Requester::onWriteError()
{
	handleRequestFailure(SocksError::CONNECTION_CLOSED);
}

void
Socks4Requester::onGenericError()
{
	handleRequestFailure(SocksError::GENERIC_ERROR);
}

void
Socks4Requester::onResponseReceived()
{
	if (m_response[0] != SOCKS_REPLY_VERSION_0) {
		handleRequestFailure(SocksError::PROTOCOL_VIOLATION);
		return;
	}
	
	unsigned char const status = m_response[1];
	if (status != SOCKS_STATUS_SUCCESS) {
		SocksError::Code code = socksStatusToErrorCode(status);
		handleRequestFailure(code);
		return;
	}
	
	handleRequestSuccess();
}

/**
 * Resets the requester, then notifies the listener (if any) of failure.
 *
 * The listener pointer is captured first because abort() detaches it.
 * The callback runs only after the requester is back in ST_INACTIVE.
 */
void
Socks4Requester::handleRequestFailure(SocksError::Code code)
{
	Listener* const observer = m_observerLink.getObserver();
	abort();
	if (observer != 0) {
		observer->onRequestFailure(SocksError(code));
	}
}

void
Socks4Requester::handleRequestSuccess()
{
	Listener* listener = m_observerLink.getObserver();
	abort(); // this will detach the listener
	if (listener) {
		listener->onRequestSuccess();
	}
}

/**
 * Maps a non-success SOCKS4 reply status byte to a SocksError code.
 *
 * - 91 (rejected or failed) maps to REJECTED_OR_FAILED.
 * - 92 / 93 (ident failures) map to AUTH_FAILURE.
 * - Anything else, including unknown values, maps to SOCKS_SERVER_FAILURE.
 */
SocksError::Code
Socks4Requester::socksStatusToErrorCode(unsigned char status)
{
	if (status == SOCKS_STATUS_REJECTED_OR_FAILED) {
		return SocksError::REJECTED_OR_FAILED;
	}
	if (status == SOCKS_STATUS_IDENT_FAILED
	    || status == SOCKS_STATUS_WRONG_IDENT) {
		return SocksError::AUTH_FAILURE;
	}
	return SocksError::SOCKS_SERVER_FAILURE;
}

/**
 * Builds a SOCKS4 CONNECT request for a numeric IPv4 destination.
 *
 * Wire layout: VN(1) CD(1) DSTPORT(2) DSTIP(4) USERID NUL.
 * The port and address bytes are copied verbatim from the sockaddr_in.
 * POSIX keeps those fields in network byte order, which is what SOCKS4
 * expects.
 *
 * USERID is NUL-terminated on the wire, so only the part of username
 * before its first embedded NUL (if any) is sent. The server stops reading
 * the field there anyway. Sending the rest would make it treat those bytes
 * as the start of the tunneled data stream.
 *
 * @param addr     destination; must be AF_INET
 * @param username SOCKS4 USERID
 * @return the complete request message
 */
std::vector<unsigned char>
Socks4Requester::createConnectMsg(
	InetAddr const& addr, std::string const& username)
{
	assert(addr.get_type() == AF_INET);
	sockaddr_in* saddr_in = (sockaddr_in*)addr.get_addr();

	// Length up to (not including) the first NUL; equals size() if none.
	size_t const userLen = strlen(username.c_str());

	vector<unsigned char> vec;
	vec.resize(9 + userLen);
	unsigned char* ptr = &vec[0];
	*ptr++ = SOCKS_VERSION_4;
	*ptr++ = SOCKS_CMD_CONNECT;
	memcpy(ptr, &saddr_in->sin_port, 2);
	ptr += 2;
	memcpy(ptr, &saddr_in->sin_addr, 4);
	ptr += 4;
	memcpy(ptr, username.c_str(), userLen + 1); // includes the terminating NUL
	assert(ptr + userLen + 1 == &vec[0] + vec.size());
	return vec;
}

/**
 * Builds a SOCKS4a CONNECT request, letting the SOCKS server resolve the
 * destination host name.
 *
 * Wire layout: VN(1) CD(1) DSTPORT(2, high byte first) DSTIP(4 = 0.0.0.1)
 * USERID NUL HOST NUL. The 0.0.0.x placeholder address (x != 0) tells a
 * SOCKS4a server that a host name follows the USERID field.
 *
 * USERID and HOST are both NUL-terminated on the wire, so each is cut at
 * its first embedded NUL (if any). The server stops reading a field there
 * anyway. Without the cut, the leftover bytes of the username would be
 * parsed as the HOST, and those of the host as tunneled data.
 *
 * @param addr     destination host name and port
 * @param username SOCKS4 USERID
 * @return the complete request message
 */
std::vector<unsigned char>
Socks4Requester::createConnectMsg(
	Socks4aTag, SymbolicInetAddr const& addr, std::string const& username)
{
	std::string const& host = addr.getHost();

	// Field lengths up to (not including) the first NUL; equal to size()
	// when the string has no embedded NUL.
	size_t const userLen = strlen(username.c_str());
	size_t const hostLen = strlen(host.c_str());

	vector<unsigned char> vec;
	vec.resize(10 + userLen + hostLen);
	unsigned char* ptr = &vec[0];
	*ptr++ = SOCKS_VERSION_4;
	*ptr++ = SOCKS_CMD_CONNECT;
	unsigned const port = addr.getPort();
	*ptr++ = static_cast<unsigned char>((port >> 8) & 0xFF);
	*ptr++ = static_cast<unsigned char>(port & 0xFF);
	// Placeholder address 0.0.0.1: marks the request as SOCKS4a.
	*ptr++ = 0x00;
	*ptr++ = 0x00;
	*ptr++ = 0x00;
	*ptr++ = 0x01;
	memcpy(ptr, username.c_str(), userLen + 1); // includes the terminating NUL
	ptr += userLen + 1;
	memcpy(ptr, host.c_str(), hostLen + 1); // includes the terminating NUL
	assert(ptr + hostLen + 1 == &vec[0] + vec.size());
	return vec;
}
