///////////////////////////////////////////////////////////////////////////////
//
//  Copyright (2008) Alexander Stukowski
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2 of the License, or
//  (at your option) any later version.
//
//  OVITO is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
///////////////////////////////////////////////////////////////////////////////

#include <core/Core.h>
#include <atomviz/atoms/AtomsObject.h>
#include <atomviz/atoms/datachannels/PositionDataChannel.h>
#include "OnTheFlyNeighborList.h"

namespace AtomViz {

/******************************************************************************
* Constructor. Sorts all atoms of the input object into a regular grid of bins
* spanning the simulation cell, so that the atoms within the cutoff radius of
* any atom can later be found by scanning only its own and the adjacent bins.
*
*   input        - Atoms object providing the positions and the simulation cell.
*   cutoffRadius - Neighbor cutoff radius; must be positive.
*
* Throws Exception if the position channel is missing, the cutoff radius is
* not positive (or NaN), the simulation cell is degenerate, or a periodic
* cell is thinner than the cutoff radius.
*
* Atoms outside a periodic cell direction are stored at their position wrapped
* back into the cell; atoms outside a non-periodic direction keep their
* position but are assigned to the nearest boundary bin.
******************************************************************************/
OnTheFlyNeighborList::OnTheFlyNeighborList(AtomsObject* input, FloatType cutoffRadius) : _cutoffRadius(cutoffRadius)
{
	CHECK_OBJECT_POINTER(input);

	PositionDataChannel* posChannel = static_object_cast<PositionDataChannel>(input->getStandardDataChannel(DataChannel::PositionChannel));
	if(!posChannel) throw Exception("Input object does not contain atomic positions. Position channel is missing.");

	_cutoffRadiusSquared = cutoffRadius * cutoffRadius;

	// Written as !(r > 0) rather than (r <= 0) so that a NaN cutoff is rejected too.
	if(!(cutoffRadius > 0.0))
		throw Exception("Invalid parameter: Neighbor list cutoff radius must be positive.");

	simCell = input->simulationCell()->cellMatrix();
	// Only a (near-)zero volume makes the cell degenerate. A left-handed cell has a
	// negative determinant but is perfectly valid, so compare the magnitude.
	FloatType cellDeterminant = simCell.determinant();
	if(cellDeterminant <= FLOATTYPE_EPSILON && cellDeterminant >= -FLOATTYPE_EPSILON)
		throw Exception("Simulation cell is degenerate.");
	simCellInverse = simCell.inverse();
	pbc = input->simulationCell()->periodicity();

	// Calculate the number of bins required in each spatial direction.
	//
	// The iterator only scans the bins adjacent to an atom's bin, so every bin must be
	// at least one cutoff radius thick *perpendicular* to its bounding cell faces.
	// Reduced coordinate i equals (row i of the inverse cell matrix) . position, hence
	// the perpendicular cell thickness along i is 1 / |row i of the inverse|.
	// For a sheared (triclinic) cell this is smaller than both the length of cell
	// vector i and 1 / |column i of the inverse|; using either of those would produce
	// bins thinner than the cutoff and silently miss neighbors.
	for(size_t i=0; i<3; i++) {
		Vector3 inverseRow(simCellInverse.column(0)[i], simCellInverse.column(1)[i], simCellInverse.column(2)[i]);
		FloatType binsThatFit = 1.0 / (cutoffRadius * Length(inverseRow));
		// Cap at 50 bins per direction. Clamping before the int conversion also
		// avoids an overflowing float-to-int cast for very large cells.
		binDim[i] = (binsThatFit >= 50) ? 50 : (int)binsThatFit;
		if(binDim[i] < 1) {
			if(pbc[i]) throw Exception("Periodic simulation cell is smaller than the neighbor cutoff radius.");
			binDim[i] = 1;
		}
	}

	typedef BinsArray::iterator iterator1;
	typedef subarray_gen<BinsArray,2>::type::iterator iterator2;
	typedef subarray_gen<BinsArray,1>::type::iterator iterator3;

	bins.resize(extents[binDim[0]][binDim[1]][binDim[2]]);

	// Clear bins. Each bin holds the head of a singly linked list of atoms.
	for(iterator1 iter1 = bins.begin(); iter1 != bins.end(); ++iter1)
		for(iterator2 iter2 = (*iter1).begin(); iter2 != (*iter1).end(); ++iter2)
			for(iterator3 iter3 = (*iter2).begin(); iter3 != (*iter2).end(); ++iter3)
				(*iter3) = NULL;

	// Put atoms into bins.
	atoms.resize(input->atomsCount());

	const Point3* p = posChannel->constDataPoint3();
	QVector<NeighborListAtom>::iterator a = atoms.begin();
	int atomIndex = 0;
	for(; a != atoms.end(); ++a, ++p, ++atomIndex) {
		a->index = atomIndex;

		// Transform atom position from absolute coordinates to reduced coordinates.
		a->pos = *p;
		Point3 reducedp = simCellInverse * (*p);

		int indices[3];
		for(size_t k=0; k<3; k++) {
			if(pbc[k]) {
				// Periodic direction: shift the atom (and its stored position) by
				// whole cell vectors until it lies inside the cell.
				while(reducedp[k] < 0) {
					reducedp[k] += 1;
					a->pos += simCell.column(k);
				}
				while(reducedp[k] > 1) {
					reducedp[k] -= 1;
					a->pos -= simCell.column(k);
				}
			}
			else {
				// Non-periodic direction: keep the position, but bin the atom
				// as if it were on the nearest cell boundary.
				reducedp[k] = max(reducedp[k], (FloatType)0);
				reducedp[k] = min(reducedp[k], (FloatType)1);
			}

			// Determine the atom's bin from its position.
			indices[k] = max(min((int)(reducedp[k] * binDim[k]), binDim[k]-1), 0);
			OVITO_ASSERT(indices[k] >= 0 && indices[k] < (int)bins.shape()[k]);
		}

		// Put atom into its bin by prepending it to the bin's linked list.
		NeighborListAtom** binList = &bins[indices[0]][indices[1]][indices[2]];
		a->nextInBin = *binList;
		*binList = &*a;
	}
}

/******************************************************************************
* Tests whether two atoms are closer to each other than the
* nearest-neighbor cutoff radius.
******************************************************************************/
bool OnTheFlyNeighborList::areNeighbors(int atom1, int atom2) const
{
	OVITO_ASSERT(atom1 >= 0 && atom1 < atoms.size());
	OVITO_ASSERT(atom2 >= 0 && atom2 < atoms.size());
	OVITO_ASSERT(atom1 != atom2);
	for(iterator neighborIter(*this, atom1); !neighborIter.atEnd(); neighborIter.next()) {
		if(neighborIter.current() == atom2) return true;
	}
	return false;
}

/******************************************************************************
* Iterator constructor. Prepares iteration over the atoms within the cutoff
* radius of the atom with the given index and advances to the first one
* (or to the end state if there is none).
*
*   neighborList - The neighbor list whose bins are scanned.
*   atomIndex    - Index of the central atom; must be a valid atom index.
******************************************************************************/
OnTheFlyNeighborList::iterator::iterator(const OnTheFlyNeighborList& neighborList, int atomIndex)
	: _list(neighborList), centerindex(atomIndex)
{
	// Check the precondition before the atom array is indexed below.
	OVITO_ASSERT(atomIndex >= 0 && atomIndex < _list.atoms.size());

	// Initial state of the bin-offset counter used by next(): its first
	// advance turns (-2,1,1) into the offset (-1,-1,-1).
	dir[0] = -2;
	dir[1] = 1;
	dir[2] = 1;
	binatom = NULL;
	center = _list.atoms[atomIndex].pos;
	neighborindex = -1;

	// Determine the bin the central atom is located in.
	// Transform atom position from absolute coordinates to reduced coordinates.
	Point3 reducedp = _list.simCellInverse * center;

	for(size_t k=0; k<3; k++) {
		// Clamp to [0,1] as the constructor does when binning, so the float-to-int
		// conversion below cannot overflow for atoms far outside the cell.
		reducedp[k] = max(reducedp[k], (FloatType)0);
		reducedp[k] = min(reducedp[k], (FloatType)1);

		// Determine the atom's bin from its position.
		centerbin[k] = max(min((int)(reducedp[k] * _list.binDim[k]), _list.binDim[k]-1), 0);
		OVITO_ASSERT(centerbin[k] >= 0 && centerbin[k] < (int)_list.bins.shape()[k]);
	}

	next();
}

/******************************************************************************
* Iterator function. Advances to the next atom within the cutoff radius of the
* central atom.
*
* The bins around the central atom's bin are visited in order of the offset
* vector dir, each component running from -1 to +1 with dir[2] changing
* fastest (3x3x3 = 27 offsets). Offsets that leave the cell along a
* non-periodic direction are skipped. Offsets that leave the cell along a
* periodic direction are wrapped to the opposite side, and pbcshift records
* the cell-vector offset subtracted from those atoms' positions.
*
* The central atom's own index is never returned, including its periodic images.
* NOTE(review): when a periodic direction has only one or two bins, the same
* bin is visited with different pbcshift values, so one atom index can be
* reported more than once (as different periodic images).
*
* Returns the index of the next neighbor, or -1 once all offsets have been
* visited. On success, neighborindex holds the same index, _delta the vector
* from the central atom to the neighbor image, and distsq its squared length.
******************************************************************************/
int OnTheFlyNeighborList::iterator::next()
{
	while(dir[0] != 2) {
		// Scan the remaining atoms of the current bin (a linked list via nextInBin).
		while(binatom) {
			_delta = binatom->pos - center - pbcshift;
			neighborindex = binatom->index;
			OVITO_ASSERT(neighborindex >= 0 && neighborindex < _list.atoms.size());
			binatom = binatom->nextInBin;
			distsq = LengthSquared(_delta);
			// Within the cutoff, and not the central atom itself.
			if(distsq <= _list._cutoffRadiusSquared && neighborindex != centerindex) {
				return neighborindex;
			}
		};
		// Current bin exhausted: advance the offset vector like an odometer,
		// dir[2] fastest, each digit wrapping from +1 back to -1.
		if(dir[2] == 1) {
			dir[2] = -1;
			if(dir[1] == 1) {
				dir[1] = -1;
				if(dir[0] == 1) {
					// All 27 offsets have been visited: enter the end state.
					dir[0]++;
					neighborindex = -1;
					return -1;
				}
				else dir[0]++;
			}
			else dir[1]++;
		}
		else dir[2]++;

		// Compute the target bin; skip it if it lies beyond a non-periodic boundary.
		currentbin[0] = centerbin[0] + dir[0];
		if(currentbin[0] == -1 && !_list.pbc[0]) continue;
		if(currentbin[0] == _list.binDim[0] && !_list.pbc[0]) continue;

		currentbin[1] = centerbin[1] + dir[1];
		if(currentbin[1] == -1 && !_list.pbc[1]) continue;
		if(currentbin[1] == _list.binDim[1] && !_list.pbc[1]) continue;

		currentbin[2] = centerbin[2] + dir[2];
		if(currentbin[2] == -1 && !_list.pbc[2]) continue;
		if(currentbin[2] == _list.binDim[2] && !_list.pbc[2]) continue;

		// Wrap out-of-range bins of periodic directions to the opposite side and
		// accumulate the matching cell-vector shift, so that
		// (atom pos - pbcshift) is the periodic image next to the central atom.
		pbcshift = NULL_VECTOR;
		for(size_t k = 0; k < 3; k++) {
			if(currentbin[k] == -1) { currentbin[k] = _list.binDim[k]-1; pbcshift += _list.simCell.column(k); }
			else if(currentbin[k] == _list.binDim[k]) { currentbin[k] = 0; pbcshift -= _list.simCell.column(k); }
		}

		// Start scanning the new bin's linked list on the next loop pass.
		binatom = _list.bins[currentbin[0]][currentbin[1]][currentbin[2]];
	}
	// Only reached when next() is called again after the end state was entered.
	neighborindex = -1;
	return -1;
}

};	// End of namespace AtomViz

