/* moveGenerator.t.cc
 */
#include "osl/search/moveGenerator.h"
#include "osl/search/searchState2.h"
#include "osl/search/shouldPromoteCut.h"
#include "osl/search/analyzer/categoryMoveVector.h"
#include "osl/move_generator/legalMoves.h"
#include "osl/move_classifier/pawnDropCheckmate.h"
#include "osl/move_classifier/moveAdaptor.h"
#include "osl/effect_util/effectUtil.h"
#include "osl/effect_util/pin.h"
#include "osl/eval/progressEval.h"
#include "osl/record/csaRecord.h"
#include "osl/record/csaString.h"
#include "osl/oslConfig.h"

#include <cppunit/TestCase.h>
#include <cppunit/extensions/HelperMacros.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <iterator>
#include <set>
#include <string>

/**
 * CppUnit fixture for osl::search::MoveGenerator.
 *
 * testAllMoves compares MoveGenerator::generateAll() with
 * LegalMoves::generate() on positions taken from the test CSA records;
 * testCopy checks that a copy-constructed MoveGenerator yields the same
 * moves as the generator it was copied from.
 */
class MoveGeneratorTest : public CppUnit::TestFixture 
{
  // CppUnit suite declaration: lists the test methods of this fixture.
  CPPUNIT_TEST_SUITE(MoveGeneratorTest);
  CPPUNIT_TEST(testAllMoves);
  CPPUNIT_TEST(testCopy);
  CPPUNIT_TEST_SUITE_END();
public:
  /// Copy-construction consistency of MoveGenerator.
  void testCopy();
  /// generateAll() versus LegalMoves on recorded games.
  void testAllMoves();
};

// Register the fixture with CppUnit's test registry so that the test
// runner discovers it.
CPPUNIT_TEST_SUITE_REGISTRATION(MoveGeneratorTest);

using namespace osl;
using namespace osl::search;

// Shorthand for the type of the checkmate object handed to the
// SearchState2 constructor in the tests below.
typedef SearchState2::checkmate_t checkmate_t;

/**
 * Verifies that copy-constructing a MoveGenerator preserves its behaviour.
 *
 * For each game record named in the test FILES list (at most count-1 files:
 * 99 normally, 9 when isShortTest is set), the positions reached during the
 * first min(63, number of moves) plies are visited.  At each position a
 * generator is initialised and then copied.  generateAll() is run on the
 * original (with sstate) and on the copy (with a copy of sstate).  The two
 * resulting move lists must be equal.
 */
void MoveGeneratorTest::testCopy()
{
  eval::ProgressEval::setUp();
  MoveGenerator::initOnce();

  extern bool isShortTest;
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  CPPUNIT_ASSERT(ifs);
  // ++file_index runs before the comparison, so at most count-1 files are read.
  const int count = isShortTest ? 10 : 100;
  const size_t max_plies = 63;
  const int limit = 600;
  int file_index = 0;
  checkmate_t checkmate;
  std::string file_name;
  // A successful >> into std::string extracts at least one character, so the
  // file name can never be empty inside the loop.
  while ((ifs >> file_name) && ++file_index < count)
  {
    if (! isShortTest)
      std::cerr << file_name << " ";
    Record rec = CsaFile(OslConfig::testCsaFile(file_name)).getRecord();
    NumEffectState initial_state(rec.getInitialState());
    SearchState2 sstate(initial_state, checkmate);
    vector<osl::Move> moves = rec.getMoves();
    MoveGenerator gen;
    const size_t plies = std::min(max_plies, moves.size());
    for (size_t ply = 0; ply < plies; ++ply) {
      // generate all moves with the original generator
      SimpleHashRecord record;
      record.setInCheck(sstate.state().inCheck());
      eval::ProgressEval eval(sstate.state());

      gen.init(limit, &record, eval, sstate.state(), true, Move());
      MoveGenerator gen2(gen);
      MoveLogProbVector search_moves;
      gen.generateAll(sstate.state().getTurn(), sstate, search_moves);

      // generate all moves with the copy, driven by a copy of the search state
      SearchState2 sstate_copy(sstate);
      MoveLogProbVector search_moves2;
      gen2.generateAll(sstate_copy.state().getTurn(), sstate_copy, search_moves2);

      CPPUNIT_ASSERT_EQUAL(search_moves, search_moves2);

      sstate.makeMove(moves[ply]);
    }
  }
}

void MoveGeneratorTest::testAllMoves()
{
  eval::ProgressEval::setUp();
  MoveGenerator::initOnce();
  
  extern bool isShortTest;
  std::ifstream ifs(OslConfig::testCsaFile("FILES"));
  CPPUNIT_ASSERT(ifs);
  int i=0;
  int count=100;
  if (isShortTest) 
    count=10;
  checkmate_t checkmate;
  std::string file_name;
  while((ifs >> file_name) && ++i<count)
  {
    if(file_name == "") 
      break;
    if (! isShortTest)
      std::cerr << file_name << " ";
    Record rec=CsaFile(OslConfig::testCsaFile(file_name)).getRecord();
    NumEffectState state(rec.getInitialState());
    vector<osl::Move> moves=rec.getMoves();
    MoveGenerator gen;
    for(unsigned int i=0;i<moves.size();i++)
    {
      MoveVector all_moves;
      LegalMoves::generate(state, all_moves);

      SimpleHashRecord record;
      record.setInCheck(state.inCheck());
      eval::ProgressEval eval(state);
      SearchState2 sstate(state, checkmate);
      gen.init(2000, &record, eval, state, true, Move());
      MoveLogProbVector search_moves;
      gen.generateAll(state.getTurn(), sstate, search_moves);
      if (search_moves.size() > all_moves.size()) {
	typedef std::set<Move> set_t;
	set_t s, t;
	for (size_t j=0; j<search_moves.size(); ++j) {
	  if (! s.insert(search_moves[j].getMove()).second) {
	    std::cerr << "\n" << state;
	    std::cerr << "dup " << search_moves[j].getMove() << "\n";
	    CPPUNIT_ASSERT(search_moves.size() <= all_moves.size());
	  }
	}
	for (size_t j=0; j<all_moves.size(); ++j) {
	  assert(t.insert(all_moves[j]).second);
	}
	vector<Move> diff;
	std::set_difference(s.begin(), s.end(), t.begin(), t.end(), 
			    std::back_inserter(diff));
	for (size_t j=0; j<diff.size(); ++j) {
	  if (move_classifier::PlayerMoveAdaptor<move_classifier::PawnDropCheckmate>
	      ::isMember(state, diff[j]))
	    continue;
	  std::cerr << "\n" << state;
	  std::cerr << "not legal " << *search_moves.find(diff[j]) << "\n";
	  analyzer::CategoryMoveVector a;
	  MoveGenerator g2;
	  eval::ProgressEval eval(state);
	  gen.init(400, &record, eval, state, true, Move());
	  gen.generateAll(state.getTurn(), sstate, a);
	  for (analyzer::CategoryMoveVector::const_iterator p=a.begin(); p!=a.end(); ++p) {
	    std::cerr << p->category << "\n";
	    std::cerr << p->moves << "\n";
	  }
	  
	  CPPUNIT_ASSERT(search_moves.size() <= all_moves.size());
	}
      }
      if (all_moves.size() != search_moves.size()) {
	for (size_t j=0; j<all_moves.size(); ++j) {
	  const Move m = all_moves[j];
	  if (ShouldPromoteCut::canIgnoreAndNotDrop(m))
	    continue;
	  if (! search_moves.find(m)) {
	    std::cerr << m << "\n" << state;
	    std::cerr << search_moves;
	  }
	  CPPUNIT_ASSERT(search_moves.find(m));
	}
      }

      ApplyMoveOfTurn::doMove(state, moves[i]);
    }
  }  
}

/* ------------------------------------------------------------------------- */
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
