/* pvGenerator.cc
 */
#include "pvGenerator.h"
#include "pvFile.h"
#include "quiesce.h"
#include "analyzer.h"

#include "osl/record/kisen.h"
#include "osl/record/csaIOError.h"
#include "osl/rating/featureSet.h"
#include "osl/rating/ratingEnv.h"
#include "osl/search/shouldPromoteCut.h"
#include "osl/progress/effect5x3.h"
#include "osl/move_generator/legalMoves.h"
#include "osl/progress/ml/newProgress.h"
#include "osl/apply_move/applyMove.h"

#include <boost/foreach.hpp>
#include <algorithm>
#include <numeric>
#include <sstream>
#include <iostream>

/** Depth handed to the Quiesce constructor (last argument) in operator(),
 *  used unless LEARN_DEEP_SEARCH is defined. */
int gpsshogi::KisenAnalyzer::quiesce_depth = 4;
/** Percentage of the assigned records to analyze; when below 100,
 *  operator() shuffles the record order and keeps only this fraction. */
int gpsshogi::KisenAnalyzer::use_percent = 100;
/** When true, forEachPosition() also evaluates a pass move as a sibling
 *  candidate, provided the side to move is not in check. */
bool gpsshogi::KisenAnalyzer::compare_pass = false;
/** Enables the extra-depth handling in PVGenerator::forEachPosition() for
 *  positions whose progress value is 0 and whose (record_id+position_id)
 *  is divisible by 3. */
bool gpsshogi::KisenAnalyzer::more_depth_in_opening = false;
/** When true, PVGenerator::forEachPosition() sorts the sibling PVs and
 *  truncates the list before writing it out. */
bool gpsshogi::PVGenerator::limit_sibling_10 = false;

/**
 * Describe one unit of work: records [f, l) of the kisen file `kisen`,
 * tagged with worker id `tid`.
 * @param a stored as allow_skip_in_cross_validation.
 */
gpsshogi::KisenAnalyzer::RecordConfig::RecordConfig(size_t tid, size_t f, size_t l,
						    const std::string& kisen, bool a)
  : first(f),
    last(l),
    thread_id(tid),
    kisen_filename(kisen),
    allow_skip_in_cross_validation(a)
{
}

/**
 * Store the record range `r`, the analysis settings `c`, and the
 * destination `out` for statistics; no work is done here.
 */
gpsshogi::KisenAnalyzer::KisenAnalyzer(const RecordConfig& r,
				       const OtherConfig& c,
				       Result *out)
  : record(r),
    config(c),
    result(out)
{
}

/** Nothing to release explicitly. */
gpsshogi::KisenAnalyzer::~KisenAnalyzer()
{
}

/** The base analyzer is not a cross-validation run; Validator overrides this. */
bool gpsshogi::KisenAnalyzer::isCrossValidation() const
{
  return false;
}

/** Hook called by operator() before any record is read; empty by default. */
void gpsshogi::KisenAnalyzer::init()
{
}

/**
 * Analyze every record assigned to this worker (record.first .. record.last-1
 * of record.kisen_filename).
 *
 * For each record the game is replayed move by move, and forEachPosition() is
 * invoked on the positions selected below with the move actually played.
 * Statistics are accumulated into *result, which is reset at the start.
 * Progress markers are written to std::cerr.
 */
void gpsshogi::KisenAnalyzer::operator()()
{
  // start from a clean result
  result->werrors.clear();
  result->toprated.clear();
  result->toprated_strict.clear();
  result->order_lb.clear();
  result->order_ub.clear();
  result->record_processed = result->skip_by_rating = 0;
  result->node_count = 0;
  result->all_errors.clear();

  init();
  KisenFile kisen_file(record.kisen_filename.c_str());
  // the ipx (rating) file is opened only when filtering by rating is requested
  boost::scoped_ptr<KisenIpxFile> ipx;
  if (config.min_rating)
    ipx.reset(new KisenIpxFile(KisenFile::ipxFileName(record.kisen_filename)));

  // record ids to visit; optionally a random subset of use_percent %
  vector<size_t> orders(record.last - record.first);
  for (size_t i=0; i<orders.size(); ++i)
    orders[i] = i + record.first;
  if (use_percent < 100) {
    std::random_shuffle(orders.begin(), orders.end());
    orders.resize(orders.size()*use_percent/100);
  }
  BOOST_FOREACH(size_t i, orders) {
    // in cross validation, optionally use only every fourth record
    if (record.allow_skip_in_cross_validation && isCrossValidation()) {
      if (i % 4)
	continue;
    }    
    // progress marker every 100 records: the record id every 500 (thread 0)
    // or 1000 (other threads) records, a dot otherwise
    if (i % 100 == 0) {
      if (i % ((record.thread_id == 0) ? 500 : 1000) == 0) {
	std::cerr << i << " ";
      }
      else
	std::cerr << ".";
    }
    // skip games in which either player is rated below config.min_rating
    if (config.min_rating
	&& (ipx->getRating(i, BLACK) < config.min_rating 
	    || ipx->getRating(i, WHITE) < config.min_rating)) {
      ++(result->skip_by_rating);
      continue;
    }
#if 0
    std::cerr << "go quiesce " << i << "\n";
#endif
    ++result->record_processed;

    NumEffectState state(kisen_file.getInitialState());
    const vector<Move> moves=kisen_file.getMoves(i);
#ifdef LEARN_DEEP_SEARCH
    Quiesce quiesce(config.my_eval, 2, 5);
#else
    Quiesce quiesce(config.my_eval, 1, quiesce_depth); 
#endif
    osl::progress::ml::NewProgress progress(state);
    stat::Average record_errors;
    for (size_t j=0; j<moves.size(); j++) {
      const Player turn = state.getTurn();
      // The opponent's king is attacked while it is our turn
      // => the previous move was illegal; give up on this record.
      if (state.inCheck(alt(turn))) {
	std::cerr << "e"; // state;
	break;
      }
      // stop once the game has progressed beyond config.max_progress
      // (the limit is only checked when it is below 16)
      if (config.max_progress < 16) {
	if (progress.progress16().value() > config.max_progress)
	  break;
      }
      // in cross validation only every fourth position is examined,
      // shifted by config.cross_validation_randomness
      if (! isCrossValidation() || (j+config.cross_validation_randomness) % 4 == 0) {
	assert(j < moves.size());
	forEachPosition(i, j, quiesce, state, progress.progress16(), moves[j]);
	if (isCrossValidation())
	  record_errors.add(result->last_error);
      }
      // collect the node count and reset the searcher before the next position
      result->node_count += quiesce.nodeCount();
      quiesce.clear();
      // advance the replayed game by the recorded move
      ApplyMoveOfTurn::doMove(state, moves[j]);
      progress.update(state, moves[j]);
    }

    // per-record average error, kept only for cross validation
    if (isCrossValidation())
      result->all_errors.push_back(boost::make_tuple
				   (i, record_errors.getAverage()));    
  }
  std::cerr << "#" << std::flush;
}

#define VERBOSE_ASSIGN

/**
 * Cut records [first, last) of `file` into `num_assignment` consecutive
 * chunks by record count.
 *
 * Every chunk but the final one spans floor(average_records) records; the
 * final chunk always ends at `last`.  One RecordConfig per chunk is stored
 * at out[written], and `written` is advanced accordingly.
 * @param verbose print each assignment to std::cerr (only if VERBOSE_ASSIGN is defined).
 */
void gpsshogi::KisenAnalyzer::splitFile(const std::string& file, size_t first, size_t last,
					int num_assignment, double average_records,
					RecordConfig *out, int& written, bool verbose)
{
  size_t chunk_begin = first;
  for (int index=0; index<num_assignment; ++index) {
    // the final chunk absorbs whatever is left
    size_t chunk_end = last;
    if (index+1 != num_assignment)
      chunk_end = chunk_begin + static_cast<int>(floor(average_records));
#ifdef VERBOSE_ASSIGN
    if (verbose)
      std::cerr << "   " << written << " " << file << " " << chunk_begin << " " << chunk_end << std::endl;
#endif
    out[written] = RecordConfig(written, chunk_begin, chunk_end, file);
    ++written;
    chunk_begin = chunk_end;
  }
}

/**
 * Cut records [first, last) of `file` into `num_assignment` consecutive
 * chunks of roughly equal weight, where the weight of a record grows with
 * its number of moves (moves beyond the 60th, 120th and 180th are counted
 * additional times).
 *
 * One RecordConfig per chunk is stored at out[written], and `written` is
 * advanced accordingly.  Nothing is written when there is no assignment or
 * the range is empty.  Chunks may be empty when the records run out before
 * the assignments do; the final chunk always ends at `last`.
 *
 * @param verbose print each assignment to std::cerr (only if VERBOSE_ASSIGN is defined).
 */
void gpsshogi::
KisenAnalyzer::splitFileWithMoves(const std::string& file, size_t first, size_t last, int num_assignment, double /*average_records*/, 
				  RecordConfig *out, int& written, bool verbose)
{
  if (num_assignment <= 0 || first >= last)
    return;
  KisenFile kisen_file(file);
  // moves[k] = total weight of records first .. first+k-1 (prefix sums)
  vector<int> moves(last-first+1);
  for (size_t i=first; i<last; ++i) {
    // read the record once instead of once per term
    const int length = static_cast<int>(kisen_file.getMoves(i).size());
    const int weight = length
      + std::max(length-60, 0)
      + std::max(length-120, 0)
      + std::max(length-180, 0);
    moves[i-first+1] = weight + moves[i-first];
  }
  const int average_moves = moves.back()/num_assignment;
  size_t cur = first;
  for (int j=0; j<num_assignment; ++j) {
    size_t last_j;
    if (j+1 == num_assignment) 
    {
      last_j = last;
    }
    else 
    {
      // Never step past `last`: earlier chunks may have overshot the average,
      // leaving fewer records than the loop below would otherwise consume.
      last_j = std::min(cur+1, last);
      // Grow the chunk while it is lighter than the average and adding the
      // next record brings it closer to the average than stopping here.
      // `last_j < last` keeps both moves[last_j-first] and
      // moves[last_j+1-first] inside the prefix-sum table.
      while (last_j < last
	     && moves[last_j-first] - moves[cur-first] < average_moves
	     && (last_j == cur+1
		 || (average_moves - (moves[last_j-first] - moves[cur-first])
		     > (moves[last_j+1-first] - moves[cur-first]) - average_moves)))
	++last_j;
    }
#ifdef VERBOSE_ASSIGN
    if (verbose)
      std::cerr << "   " << written << " " << file << " " << cur << " " << last_j
		<< " (" << moves[last_j-first] - moves[cur-first] << ")" << std::endl;
#endif
    out[written] = RecordConfig(written, cur, last_j, file);
    ++written;
    cur = last_j;
  }
}

/**
 * Distribute the records numbered kisen_start .. kisen_start+num_records-1
 * (counted across all of `files`, in order) among `split` workers.
 *
 * Files that cannot be opened (osl::CsaIOError) are treated as empty.
 * num_records == 0, or a range running past the end, means "up to the last
 * record".  The resulting assignments are written to `out` by
 * splitFileWithMoves().
 *
 * Each file in the range first receives one worker (for at most `split`
 * files); remaining workers are then handed out greedily to the file with
 * the largest per-worker amount.  Files in the range beyond the first
 * `split` ones receive no assignment.
 *
 * NOTE(review): `out` presumably has room for `split` entries -- confirm
 * against the caller.  The last (unnamed) parameter, min_rating, is unused.
 */
void gpsshogi::
KisenAnalyzer::distributeJob(size_t split, RecordConfig *out, size_t kisen_start, size_t num_records,
			     const std::vector<std::string>& files, size_t /*min_rating*/)
{
  // totals[i] = number of records in files[0..i-1]; sizes[i] = records in files[i]
  size_t total = 0;
  vector<size_t> totals(files.size()+1), sizes(files.size());
  for (size_t i=0; i<files.size(); ++i) {
    totals[i] = total;
    try {
      KisenFile kisen_file(files[i].c_str());
      sizes[i] = kisen_file.size(); // todo min_rating
#ifdef VERBOSE_ASSIGN
      if (files.size() > 1)
	std::cerr << "  size " << files[i] << " " << sizes[i] << std::endl;
#endif
    }
    catch (osl::CsaIOError&) {
      std::cerr << "open failed " << files[i] << "\n";
      sizes[i] = 0;
    }
    total += sizes[i];
  }
  totals[files.size()] = total;
  if (total <= kisen_start) {
    std::cerr << "warning KisenAnalyzer::distributeJob kisen_start too large "
	      << kisen_start << " >= " << total << std::endl;
    return;
  }
  // find the range of files to be processed
  if (num_records == 0 || kisen_start+num_records > total)
    num_records = total - kisen_start;

  // files [start_file, end_file) cover the requested records
  size_t start_file = 0;
  while (start_file < files.size() && totals[start_file+1] <= kisen_start)
    ++start_file;
  size_t end_file = start_file+1;
  while (end_file <= files.size() && totals[end_file] < kisen_start+num_records)
    ++end_file;

  // first, try to assign at least one cpu for each file
  int cpu_left = split;
  // number of records to skip at the head of the first file
  const int offset = kisen_start - totals[start_file];
  vector<int> assignment(files.size()); 
  vector<double> amount(files.size());
  for (size_t i=0; i<std::min(split, end_file-start_file); ++i) {
    const int file = start_file+i;
    assignment[file] = 1;
    amount[file] = sizes[file];
    if (i == 0) 
    {
      assert(amount[start_file] > offset);
      amount[start_file] -= offset;
    }
    if (start_file+i+1 == end_file) 
    {
      // drop the records of the last file that lie beyond the requested range
      amount[end_file-1] -= (totals[end_file]-num_records-offset-totals[start_file]);
      assert(amount[end_file-1] > 0);
    }
    --cpu_left;
  }
  // then, assign other cpus to the largest files in a greedy way.
  // amount[file] is kept as the per-cpu share of that file.
  while (cpu_left--) {
    size_t file = std::max_element(amount.begin(), amount.end()) - amount.begin();
    double total = amount[file]*assignment[file];
    assignment[file]++;
    amount[file] = total / assignment[file];
  }
  
  // write out
  int written = 0;
  for (size_t i=0; i<std::min(split, end_file-start_file); ++i) {
    const size_t file_id = start_file+i;
    if (assignment[file_id] == 0)
      continue;
    // record range inside this file, trimmed at both ends of the request
    size_t first = 0, last = sizes[file_id];
    if (i == 0)
      first = kisen_start - totals[start_file];
    if (file_id+1 == end_file) 
      last -= (totals[end_file]-num_records-offset-totals[start_file]);
#if 1
    splitFileWithMoves(files[file_id], first, last, assignment[file_id], amount[file_id], out, written,
		       files.size() > 1);
#else
    splitFile(files[file_id], first, last, assignment[file_id], amount[file_id], out, written,
	      files.size() > 1);
#endif
  }
}

/* ------------------------------------------------------------------------- */

/** Forward the record range, settings and result sink to KisenAnalyzer. */
gpsshogi::Validator::Validator(const RecordConfig& r,
			       const OtherConfig& c,
			       Result *out)
  : KisenAnalyzer(r, c, out)
{
}

/** Nothing to release explicitly. */
gpsshogi::Validator::~Validator()
{
}

/** A Validator always runs in cross-validation mode. */
bool gpsshogi::Validator::isCrossValidation() const
{
  return true;
}

/**
 * Compare the move played in the record (`best_move`) with all other
 * candidate moves in `state` and accumulate the statistics into *result.
 *
 * The position after each move is searched with `quiesce`; siblings are
 * searched with a window of config.window_by_pawn pawns around the value
 * of the recorded move.  Moves whose search fails or returns an infinite
 * value are ignored.  result->last_error is reset to 0 on entry and stays 0
 * when the recorded move itself cannot be evaluated.
 *
 * All scores collected in `values` are seen from the side to move
 * (multiplied by turn_coef), so larger always means better for the mover.
 *
 * @param quiesce   searcher used for every candidate
 * @param state     position before the recorded move
 * @param best_move move played in the record
 */
void gpsshogi::
Validator::forEachPosition(int /*record_id*/, int /*position_id*/, 
			   Quiesce& quiesce, const NumEffectState& state, 
			   osl::Progress16, Move best_move)
{
  const int turn_coef = (state.getTurn() == BLACK) ? 1 : -1;
  double cur_errors = 0.0;
  result->last_error = 0.0;
  PVVector pv;
  pv.push_back(best_move);
  int best_value;  
  {
    NumEffectState new_state = state;
    ApplyMoveOfTurn::doMove(new_state, best_move);
    if (! quiesce.quiesce(new_state, best_value, pv))
      return;
  }
  if (abs(best_value) == quiesce.infty(BLACK))
    return;
  // value of the recorded move from the mover's point of view; this is the
  // key under which it is stored in `values`
  const int best_for_turn = best_value*turn_coef;

  MoveVector moves;
  LegalMoves::generate(state, moves);
  if (compare_pass && ! state.inCheck())
    moves.push_back(Move::PASS(state.getTurn()));

  vector<int> values; values.reserve(moves.size());
  values.push_back(best_for_turn);
  for (MoveVector::const_iterator p=moves.begin(); p!=moves.end(); ++p) {
    if (*p == best_move)
      continue;
    pv.clear();
    int value;
    {
      NumEffectState new_state = state;
      ApplyMoveOfTurn::doMove(new_state, *p);
      pv.push_back(*p);
      const int width = (int)(config.window_by_pawn*config.my_eval->pawnValue());
      if (! quiesce.quiesce(new_state, value, pv, best_value - width, best_value + width))
	continue;
    }
    if (abs(value) == quiesce.infty(BLACK))
      continue;
    values.push_back(value*turn_coef);
    cur_errors += SigmoidUtil::tx((value - best_value)*turn_coef, config.my_eval->pawnValue());
  }
  std::sort(values.begin(), values.end());
  result->toprated.add(values.back() == best_for_turn); // top?
  result->toprated_strict.add(values.back() == best_for_turn
			      && values.size() > 1 && values[values.size()-2] != best_for_turn);
  // Rank of the recorded move counted from the best candidate.  The search
  // key must be in the same (mover-relative) scale as the entries of
  // `values`; using the raw best_value gives wrong ranks when WHITE moves.
  result->order_lb.add(values.end() - std::lower_bound(values.begin(), values.end(), best_for_turn));
  result->order_ub.add(values.end() - std::upper_bound(values.begin(), values.end(), best_for_turn));
  result->werrors.add(cur_errors);
  result->last_error = cur_errors;
}

/* ------------------------------------------------------------------------- */

/**
 * Set up the analyzer base and derive this worker's PV output file name
 * from `pv_base` and the worker id in `r`.
 */
gpsshogi::PVGenerator::PVGenerator(const std::string& pv_base,
				   const RecordConfig& r,
				   const OtherConfig& c,
				   Result *out)
  : KisenAnalyzer(r, c, out),
    pv_filename(pv_file(pv_base, r.thread_id))
{
}

/** Nothing to release explicitly. */
gpsshogi::PVGenerator::~PVGenerator()
{
}

/**
 * Name of the PV file for worker `thread_id`:
 * `pv_base` followed by the id and the suffix ".gz".
 */
const std::string gpsshogi::PVGenerator::pv_file(const std::string& pv_base, size_t thread_id)
{
  std::ostringstream name;
  name << pv_base;
  name << thread_id;
  name << ".gz";
  return name.str();
}

void gpsshogi::
PVGenerator::init() 
{
  pw.reset(new PVFileWriter(pv_filename.c_str()));
}

/**
 * Search the recorded move and its siblings in `state` and write the
 * resulting principal variations (PVs) to the PV file.
 *
 * The PV of the recorded move (`best_move`) is written first, followed by
 * the PVs of the other candidate moves whose search succeeded inside a
 * window of config.window_by_pawn pawns around the recorded move's value.
 * Nothing is written when the recorded move cannot be evaluated.
 * Statistics (toprated, toprated_strict, werrors, last_error) are updated
 * only when at least one sibling PV was obtained.
 *
 * When more_depth_in_opening is set, the progress value is 0 and
 * (record_id+position_id) is divisible by 3, the search depths of `quiesce`
 * are changed temporarily and the siblings are searched in two passes; the
 * original depths are restored before returning.
 *
 * @param record_id   id of the game record, stored in the PV file
 * @param position_id index of the position in the record, stored in the PV file
 * @param quiesce     searcher used for every candidate
 * @param state       position before the recorded move
 * @param progress    progress of `state`
 * @param best_move   move played in the record
 */
void gpsshogi::
PVGenerator::forEachPosition(int record_id, int position_id, 
			     Quiesce& quiesce, const NumEffectState& state, 
			     osl::Progress16 progress, Move best_move)
{
  const bool in_opening = more_depth_in_opening && progress.value() == 0
    && ((record_id+position_id)%3 == 0);

  PVVector pv;
  pv.push_back(best_move);
  int best_value;
  // remember the current depths so that they can be restored later
  // (note: this local hides the static member KisenAnalyzer::quiesce_depth)
  const int full_depth = quiesce.fullWidthDepth();
  const int quiesce_depth = quiesce.quiesceDepth();
  {
    NumEffectState new_state = state;
    ApplyMoveOfTurn::doMove(new_state, best_move);
    // in the opening, the recorded move is searched with the first depth set to 0
    if (in_opening)
      quiesce.setDepth(0, quiesce_depth);
    const bool good_pv = quiesce.quiesce(new_state, best_value, pv);
    if (in_opening)
      quiesce.setDepth(full_depth, quiesce_depth);
    if (! good_pv)
      return;
#ifdef DEBUG_ALL
    std::cerr << "*BEST*\n" << new_state << pv;
#endif
  }
  if (abs(best_value) == quiesce.infty(BLACK))
    return;
#ifdef WATCH_PV
  std::cerr << state << pv;
#endif
  // the PV of the recorded move always comes first for a position
  pw->newPosition(record_id, position_id);
  pw->addPv(pv);

  MoveVector moves;
  LegalMoves::generate(state, moves);
  if (compare_pass && ! state.inCheck())
    moves.push_back(Move::PASS(state.getTurn()));

  const int turn_coef = (state.getTurn() == BLACK) ? 1 : -1;
  double cur_errors = 0.0;
  size_t move_id = 0;
  // (value from the mover's point of view, PV) for each sibling
  typedef vector<std::pair<int,PVVector> > vector_t;
  vector_t values;
  values.reserve(in_opening ? moves.size()*2 : moves.size());
  // one pass normally; two passes (first depth 2, then 0) in the opening
  for (int d=0; d<(in_opening ? 2 : 1); ++d) 
  {
    if (in_opening)
      quiesce.setDepth((d == 0) ? 2 : 0, quiesce_depth);
    for (MoveVector::const_iterator p=moves.begin(); p!=moves.end(); ++p, ++move_id) {
      if (*p == best_move)
	continue;
      // skip non-pass moves flagged by ShouldPromoteCut whose destination
      // square is attacked by the opponent
      if (! p->isPass()
	  && osl::search::ShouldPromoteCut::canIgnoreAndNotDrop(*p)
	  && state.hasEffectBy(alt(p->player()), p->to()))
	continue;

      pv.clear();
      int value;
      {
	NumEffectState new_state = state;
	ApplyMoveOfTurn::doMove(new_state, *p);
	pv.push_back(*p);
	const int width = (int)(config.window_by_pawn*config.my_eval->pawnValue());
	if (! quiesce.quiesce(new_state, value, pv, best_value - width, best_value + width))
	  continue;
      }
      if (abs(value) == quiesce.infty(BLACK))
	continue;
      // NOTE(review): in the second pass the remaining moves are abandoned
      // at the first PV longer than two moves -- confirm that `break`
      // (rather than `continue`) is intended.
      if (d > 0 && pv.size() > 2)
	break;
#ifdef WATCH_PV
      if (p - moves.begin() < 4)
	std::cerr << pv;
#endif
#ifdef DEBUG_PV
      {
	HistoryState stack;
	Analyzer::makeLeaf(stack, pv);
	if (stack.state.inCheck(BLACK) || stack.state.inCheck(WHITE)) {
	  // continue;
	  std::cerr << state << pv << value << "\n" << record_id << " "
		    << position_id << " " << move_id << "\n";
	}

	assert(! stack.state.inCheck(BLACK));
	assert(! stack.state.inCheck(WHITE));
      }
#endif
      cur_errors += SigmoidUtil::tx((value - best_value)*turn_coef, config.my_eval->pawnValue());
      values.push_back(std::make_pair(value*turn_coef, pv));
    }
  }
  if (in_opening)	// restore
    quiesce.setDepth(full_depth, quiesce_depth);
  if (values.empty())
    return;
  if (! limit_sibling_10) 
  {
    // the recorded move is top rated if no sibling scores higher
    vector_t::const_iterator p
      = std::max_element(values.begin(), values.end());
    result->toprated.add(p->first <= best_value*turn_coef);
    result->toprated_strict.add(p->first < best_value*turn_coef);    
  }
  else
  {
    std::sort(values.begin(), values.end());
    result->toprated.add(values.back().first <= best_value*turn_coef);
    result->toprated_strict.add(values.back().first < best_value*turn_coef);
    // keep every sibling not better than the recorded move, plus at most
    // ten of the siblings that follow in ascending order
    size_t i=0;
    for (;i<values.size(); ++i) 
      if (values[i].first > best_value*turn_coef)
	break;
    values.resize(std::min(i+10, values.size()));
  }
  for (size_t i=0; i<values.size(); ++i)
    pw->addPv(values[i].second);
  result->werrors.add(cur_errors);
  result->last_error = cur_errors;
}

/* ------------------------------------------------------------------------- */
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; coding:utf-8
// ;;; End:
