/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 1999-2007 Soeren Sonnenburg
 * Copyright (C) 1999-2007 Fraunhofer Institute FIRST and Max-Planck-Society
 */

#include "classifier/svm/GPBTSVM.h"
#include "classifier/svm/gpdt.h"
#include "classifier/svm/gpdtsolve.h"
#include "lib/io.h"

/** Default constructor.
 *
 * Default-constructs the CSVM base and starts with no model (model == NULL).
 */
CGPBTSVM::CGPBTSVM()
	: CSVM(),
	  model(NULL)
{
	/* nothing else to set up; training allocates everything it needs */
}

/** Construct a GPBT SVM from a constant, a kernel and a set of labels.
 *
 * All three arguments are forwarded unchanged to the CSVM base constructor;
 * the model pointer starts out NULL.
 *
 * @param C   constant forwarded to CSVM(C, k, lab)
 * @param k   kernel forwarded to the base class
 * @param lab labels forwarded to the base class
 */
CGPBTSVM::CGPBTSVM(DREAL C, CKernel* k, CLabels* lab)
	: CSVM(C, k, lab),
	  model(NULL)
{
	/* base class holds C, kernel and labels; nothing else to initialise */
}

/** Destructor: releases `model` with free().
 *
 * Both constructors initialise `model` to NULL. Releasing a NULL pointer is
 * skipped, which matches free(NULL) being a no-op.
 */
CGPBTSVM::~CGPBTSVM()
{
	if (model != NULL)
		free(model);
}

/** Train the SVM with the GPDT (gradient projection-based decomposition) solver.
 *
 * The function proceeds as follows:
 *  - Copies the kernel, the two-class integer labels and the solver options
 *    (epsilon, kernel cache size, C1, qpsize, linadd) into a QPproblem.
 *  - Runs QPproblem::gpdtsolve().
 *  - Stores every example whose multiplier exceeds prob.DELTAsv as a support
 *    vector with alpha = solution[i] * label[i].
 *  - Takes the bias from prob.bee and the objective value from
 *    prob.objective_value.
 *
 * Preconditions (checked with ASSERT): a kernel is set, and labels are set,
 * non-empty and form a two-class labeling.
 *
 * @return true once the model has been created
 */
bool CGPBTSVM::train()
{
	QPproblem prob;                           /* object containing the solvers  */

	ASSERT(get_kernel());
	ASSERT(get_labels() && get_labels()->get_num_labels());
	ASSERT(get_labels()->is_two_class_labeling());

	int num_lab = 0;
	int* lab = get_labels()->get_int_labels(num_lab);

	/* Both buffers handed to prob are released by this function at the end:
	 * KER with delete and lab with delete[]. */
	prob.KER = new sKernel(get_kernel(), num_lab);
	prob.y = lab;
	ASSERT(prob.KER);
	/* Use the length returned with lab, so prob.y, prob.KER and the solution
	 * vector below all agree on the number of examples. */
	prob.ell = num_lab;
	SG_INFO( "%d trainlabels\n", prob.ell);

	/*** set options defaults ***/
	prob.delta = epsilon;
	prob.maxmw = get_kernel()->get_cache_size();
	prob.verbosity       = 0;
	prob.preprocess_size = -1;
	prob.projection_projector = -1;
	prob.c_const = get_C1();
	prob.chunk_size = get_qpsize();
	prob.linadd = get_linadd_enabled();

	/* Clamp the working-set size (chunk_size) and the number of new variables
	 * per iteration (q). q is forced into [2, chunk_size] and made even. */
	if (prob.chunk_size < 2)      prob.chunk_size = 2;
	if (prob.q <= 0)              prob.q = prob.chunk_size / 3;
	if (prob.q < 2)               prob.q = 2;
	if (prob.q > prob.chunk_size) prob.q = prob.chunk_size;
	prob.q = prob.q & (~1);
	if (prob.maxmw < 5)
		prob.maxmw = 5;

	/*** set the problem description for final report ***/
	/* The report uses GPDT's option letters: "q" is the working-set size
	 * (chunk_size) and "n" is the number of new variables (prob.q). */
	SG_INFO( "\nTRAINING PARAMETERS:\n");
	SG_INFO( "\tNumber of training documents: %d\n", prob.ell);
	SG_INFO( "\tq: %d\n", prob.chunk_size);
	SG_INFO( "\tn: %d\n", prob.q);
	SG_INFO( "\tC: %lf\n", prob.c_const);
	SG_INFO( "\tkernel type: %d\n", prob.ker_type);
	SG_INFO( "\tcache size: %dMb\n", prob.maxmw);
	SG_INFO( "\tStopping tolerance: %lf\n", prob.delta);

	/*** derive the remaining "-1 = automatic" options ***/
	if (prob.preprocess_size == -1)
		prob.preprocess_size = (int) ( (double)prob.chunk_size * 1.5 );

	if (prob.projection_projector == -1)
	{
		if (prob.chunk_size <= 20) prob.projection_projector = 0;
		else prob.projection_projector = 1;
	}

	/*** compute the problem solution ***/
	double* solution = new double[prob.ell];
	prob.gpdtsolve(solution);

	CSVM::set_objective(prob.objective_value);

	/* Count support vectors (multiplier > DELTAsv) and bounded support
	 * vectors (multiplier within DELTAsv of C). */
	int num_sv = 0;
	int bsv = 0;
	for (int i = 0; i < prob.ell; i++)
	{
		if (solution[i] > prob.DELTAsv)
		{
			num_sv++;
			if (solution[i] > (prob.c_const - prob.DELTAsv)) bsv++;
		}
	}

	create_new_model(num_sv);
	set_bias(prob.bee);

	SG_INFO("SV: %d BSV = %d\n", num_sv, bsv);

	/* Store the support vectors with label-signed multipliers. */
	int k = 0;
	for (int i = 0; i < prob.ell; i++)
	{
		if (solution[i] > prob.DELTAsv)
		{
			set_support_vector(k, i);
			set_alpha(k++, solution[i]*get_labels()->get_label(i));
		}
	}

	delete[] solution;

	/* Release the kernel wrapper allocated above (it was previously leaked).
	 * Detach both buffers from prob before it goes out of scope, so that
	 * nothing in QPproblem's teardown can touch or free them again. */
	delete prob.KER;
	prob.KER = NULL;
	prob.y = NULL;
	delete[] lab;

	return true;
}
