// learn.cc
/**
 * @file
 * Learns the weights of a linear model that predicts proof or disproof numbers from h_seed features.
 */
#include "osl/checkmate/h_seed/seedMap.h"
#include "osl/checkmate/h_seed/features.h"
#include "osl/stat/instanceMultiplier.h"
#include "osl/stat/activityCount.h"
#include "osl/stat/diagonalPreconditioner.h"
#include "osl/stat/iterativeLinearSolver.h"
#include "osl/stat/twoDimensionalStatistics.h"
#include "osl/stat/average.h"
#include "osl/stat/sparseInstance.h"
#include <deque>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <iostream>
#include <fstream>
#include <unistd.h>

using namespace osl;
using namespace osl::stat;
using namespace osl::checkmate::h_seed;

/**
 * Print the command-line synopsis to stderr and terminate with status 1.
 *
 * @param prog program name shown in the synopsis (normally argv[0])
 */
void usage(const char *prog)
{
  std::ostream& err = std::cerr;
  err << "Usage: " << prog
      << " -H heuristic_type [-VC] [-E eps] [-I instances] [-P positions] [-T threshold]"
      << " [-w weights filename] [-s size] [-S saturation] [-M [pd]] \n";
  err << " -V validate\n";
  err << " -C count\n";
  err << std::endl;
  exit(1);
}

// Forward declarations of the phases driven by main().
void read(const char *x_filename, const char *y_filename);
void learn(const char *output_filename);
double validate(const char *weights_filename);
void count(const char *count_filename, unsigned int begin=0);

// Default input files (overridable with -I and -P) and the activity-count
// output file (not settable from the command line).
const char *instance_filename = "instances-l.txt";
const char *position_filename = "positions-l.txt";
const char *count_filename = "count.txt";

// Tolerance handed to the iterative linear solver (-E).
double eps = 0.00001;
// Activity threshold handed to ActivityCount::setBinary and to the
// instance multiplier (-T).
int threshold = 250;
// Number of leading instances reserved for validation: learn() skips them
// when counting feature activity and passes this value to the multiplier.
const unsigned int reserve_for_validation = 1024;

// Feature extractor selected with -H; main() prints usage while it is null.
Features *features=0;

// Which quantity is learned: proof numbers (-M p) or disproof numbers (-M d).
enum TargetMode { PROOF, DISPROOF };
TargetMode mode = PROOF;
// Upper bound on the number of instances read (-s); the default, the
// largest unsigned int, means effectively unlimited.
unsigned int max_instances = static_cast<unsigned int>(-1);
// Proof/disproof numbers are clipped to this value before learning (-S).
int saturation = 1024;

/**
 * Entry point: parse options, load the data and run one of three modes.
 *
 * Options (see usage()):
 *   -H name  feature set, looked up with SeedMap::find (required)
 *   -M p|d   learn proof (p, default) or disproof (d) numbers
 *   -I file  instance (feature) file     -P file  position (target) file
 *   -w file  weights file                -E eps   solver tolerance
 *   -s n     maximum number of instances -S n     saturation of targets
 *   -T n     activity threshold
 *   -V       validation mode             -C       count mode
 * Without -V or -C the weights are learned.  Unknown options (including the
 * 'v' and 'h' letters present in the option string) print usage and exit.
 */
int main(int argc, char **argv)
{
  nice(20);

  const char *weights_filename = "weights.txt";
  bool count_mode = false;
  bool validation_mode = false;
  SeedMap seeds;

  const char *program_name = argv[0];
  bool error_flag = false;
  extern char *optarg;
  extern int optind;
  // getopt() returns int; keeping it in a char would break the -1
  // (end of options) test on platforms where char is unsigned.
  int c;
  while ((c = getopt(argc, argv, "CH:E:I:M:P:w:s:S:T:Vvh")) != -1)
  {
    switch(c)
    {
    case 'C':   count_mode = true;
      break;
    case 'E':	eps = atof(optarg);
      break;
    case 'H':	features = seeds.find(optarg);
      break;
    case 'I':   instance_filename = optarg;
      break;
    case 'M':
      switch (optarg[0])
      {
      case 'p': mode = PROOF;
	break;
      case 'd': mode = DISPROOF;
	break;
      default:
	// Reject unknown modes explicitly; an assert would vanish under NDEBUG
	// and silently keep the previous mode.
	error_flag = true;
      }
      break;
    case 'P':   position_filename = optarg;
      break;
    case 's':	max_instances = atoi(optarg);
      break;
    case 'S':	saturation = atoi(optarg);
      break;
    case 'T':	threshold = atoi(optarg);
      break;
    case 'w':   weights_filename = optarg;
      break;
    case 'V':   validation_mode = true;
      break;
    default:	error_flag = true;
    }
  }
  argc -= optind;
  argv += optind;
  if (error_flag || (! features))
    usage(program_name);
  read(instance_filename, position_filename);
  if (validation_mode)
    validate(weights_filename);
  else if (count_mode)
    count(count_filename);
  else
    learn(weights_filename);

  return 0;
}


// Sparse feature vectors filled by read(), one per loaded record.
static SparseInstanceVector instances;
// Regression targets filled by read() in step with instances: the clipped
// proof or disproof number selected by mode.
typedef std::deque<double> target_t;
static target_t target;

/**
 * Load training instances and their regression targets.
 *
 * Each record of @p x_filename is a record id, a move id and a feature
 * description parsed by *features.  The matching record of @p y_filename
 * has the same ids followed by the proof and disproof numbers.  Both numbers
 * are clipped to saturation, and the one selected by mode becomes the
 * target.  At most max_instances records are read.
 *
 * @param x_filename file of feature records (instances)
 * @param y_filename file of proof/disproof records (positions)
 */
void read(const char *x_filename, const char *y_filename)
{
  std::cerr << "reading " << x_filename << ", " << y_filename << "\n";
  std::ifstream xs(x_filename);
  std::ifstream ys(y_filename);
  // Report unreadable inputs instead of silently loading nothing.
  if (! xs)
    std::cerr << "warning: cannot open " << x_filename << "\n";
  if (! ys)
    std::cerr << "warning: cannot open " << y_filename << "\n";

  int x_record_id, x_move_id;
  int y_record_id, y_move_id, proof, disproof;
  // Test the size limit first, so no extra record is parsed into *features
  // once max_instances has been reached.
  while ((instances.size() < max_instances)
	 && (xs >> x_record_id >> x_move_id >> (*features))
	 && (ys >> y_record_id >> y_move_id >> proof >> disproof))
  {
    // The two files must describe the same record, in the same order.
    assert(x_record_id == y_record_id);
    assert(x_move_id == y_move_id);
    proof = std::min(proof, saturation);
    disproof = std::min(disproof, saturation);

    const SparseInstance a = features->makeInstance();
    instances.push_back(a);
    target.push_back((mode == PROOF) ? proof : disproof);
  }
  // assert(reserve_for_validation*10 < instances.size());
}

/**
 * Evaluate a weight vector on the leading instances of the data set.
 *
 * @param weights       weight vector; the element at features->maxIndex()
 *                      is used as the intercept
 * @param num_instances number of leading instances to evaluate, clipped to
 *                      the number actually loaded
 * @param verbose       when true, write "prediction target" pairs to stdout
 * @return correlation between predictions and targets.  The correlation and
 *         the mean squared error are also written to stderr.
 */
double validate(valarray_t& weights, size_t num_instances, bool verbose)
{
  TwoDimensionalStatistics predicted_vs_actual;
  Average squared_error;
  const size_t limit = std::min(instances.size(), num_instances);
  const double bias = weights[features->maxIndex()];
  for (size_t k = 0; k < limit; ++k)
  {
    const double predicted = instances[k].dot_product(&weights[0], bias);
    const double actual = target[k];
    if (verbose)
      std::cout << predicted << " " << actual << "\n";
    squared_error.add((actual - predicted)*(actual - predicted));
    predicted_vs_actual.addInstance(predicted, actual);
  }
  const double r = predicted_vs_actual.correlation();
  std::cerr << r << "\n";
  std::cerr << squared_error.getAverage() << "\n";
  return r;
}

double validate(const char *weights_filename)
{
  valarray_t weights(features->maxIndex()+1);
  std::ifstream is(weights_filename);
  for (unsigned int i=0; i<weights.size(); ++i)
  {
    is >> weights[i];
  }
  assert(is);
  return validate(weights, instances.size(), true);
}


boost::scoped_ptr<ActivityCount> counts;
void count(const char *count_filename, unsigned int begin)
{
  counts.reset(new ActivityCount(features->maxIndex()+1));
  for (size_t i=begin; i<instances.size(); ++i)
  {
    const SparseInstance& a = instances[i];
    for (unsigned int j=0; j<a.size(); ++j)
    {
      counts->add(a[j].index);
    }
  }
  counts->add(features->maxIndex(), instances.size());
  counts->show(count_filename);
  counts->setBinary(threshold);
}

/**
 * InstanceMultiplier used by learn().
 *
 * Adds a validation hook that scores the current weights on the validation
 * prefix without verbose output.
 */
struct MyInstanceMultiplier : public InstanceMultiplier
{
  // Forwards every argument unchanged to InstanceMultiplier.  Judging from
  // learn(), w is the weight vector being solved for, t the activity
  // threshold and f the weights file name.  TODO: confirm against the base.
  MyInstanceMultiplier(unsigned int dimension, 
		       unsigned int reserve_for_validation,
		       const SparseInstanceVector& instances,
		       const ActivityCount& activity,
		       valarray_t& w, int t, const char *f)
    : InstanceMultiplier(dimension, reserve_for_validation, instances,
			 activity, w, t, f)

  {
  }
  // Runs newIteration() one last time on destruction.  Presumably this
  // processes/saves the final weights like every other iteration.
  // NOTE(review): confirm against InstanceMultiplier.
  ~MyInstanceMultiplier()
  {
    newIteration();
  }
  // Correlation of the current weights on the validation prefix.
  // NOTE(review): reserve_for_validation binds to a base-class member if
  // InstanceMultiplier declares one; otherwise it binds to the global constant.
  double validate() const
  {
    return ::validate(weights, reserve_for_validation, false);
  }
};

/**
 * Learning mode: fit the weight vector with an iterative linear solver.
 *
 * Steps:
 *   1. Feature activity is counted on instances after the validation prefix.
 *   2. Initial weights are read from @p weights_filename when possible.
 *      Weights of features that counts->isActive() reports as inactive are
 *      forced to zero.
 *   3. The system A w = X^t y is solved, where A is applied by
 *      MyInstanceMultiplier (presumably X^t X over the training instances),
 *      using a diagonal preconditioner.
 *
 * @param weights_filename source of the initial weights.  It is also handed
 *        to MyInstanceMultiplier (presumably as the output file).
 */
void learn(const char *weights_filename)
{
  std::cerr << instances.size() << "\n";
  // The first reserve_for_validation instances are held out; without any
  // instance beyond them there is nothing to train on.
  if (instances.size() <= reserve_for_validation)
  {
    std::cerr << "error: need more than " << reserve_for_validation
	      << " instances to learn, got " << instances.size() << "\n";
    return;
  }
  count(count_filename, reserve_for_validation);

  const size_t dim = features->maxIndex()+1;
  valarray_t result(0.0, dim);
  {
    std::ifstream is(weights_filename);
    for (size_t i=0; i<dim; ++i)
    {
      is >> result[i];
      if (! counts->isActive(i))
	result[i] = 0.0;
    }
    if (! is)
      std::cerr << "warning: initial weights read failure\n";
  }
  valarray_t b(dim);
  valarray_t diag_inv(dim);

  MyInstanceMultiplier prodA(dim, reserve_for_validation, instances, *counts,
			     result, threshold, weights_filename);

  std::cerr << "computing x^t y\n";
  DoubleIteratorReader<target_t::const_iterator> y(target.begin());
  prodA.computeXtY(y, &b[0], &diag_inv[0]);
  DiagonalPreconditioner preconditioner(dim);
  preconditioner.setInverseDiagonals(&diag_inv[0]);
  std::cerr << "preconditioner\n";

  IterativeLinearSolver solver(prodA, &preconditioner, 40, eps);
  std::cerr << "solver started ";
  int err = 0;
  // Initialized so the report below never reads indeterminate values.
  int iter = 0;
  double tol = 0.0;
#if 0
  std::cerr << "using bicgstab\n";
  err = solver.solve_by_BiCGSTAB(b, result, &iter, &tol);
#else
  std::cerr << "using cg\n";
  err = solver.solve_by_CG(b, result, &iter, &tol);
#endif
  std::cerr << "tolerance achieved " << tol << "\n";
  std::cerr << "iterations " << iter << "\n";
  // Previously the solver status was computed and dropped.
  if (err)
    std::cerr << "warning: solver returned error code " << err << "\n";
}

/* ------------------------------------------------------------------------- */
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
