///////////////////////////////////////////////////////////////////////////////
//
//  Copyright (2008) Alexander Stukowski
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2 of the License, or
//  (at your option) any later version.
//
//  OVITO is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
///////////////////////////////////////////////////////////////////////////////

#include <core/Core.h>
#include <core/utilities/ProgressIndicator.h>
#include "FindGrains.h"
#include "AnalyzeMicrostructureModifier.h"

#include <atomviz/utils/OnTheFlyNeighborList.h>

namespace CrystalAnalysis {

/******************************************************************************
* Default constructor.
*
* Loads the symmetry operations stored in the Qt resource
* ":/crystalanalysis/fcc_spacegroup.a1". The file is expected to contain one
* header line followed by 48 matrices of 3x3 values, each read row by row.
* Only the operations with a positive determinant (proper rotations) are
* kept; the debug assertion below expects exactly 24 of them.
*
* NOTE(review): the result of QFile::open() is not checked. If the resource
* cannot be opened, the rotation list ends up empty and only the debug
* assertion catches it.
******************************************************************************/
FindGrains::FindGrains()
{
	QFile resourceFile(":/crystalanalysis/fcc_spacegroup.a1");
	resourceFile.open(QIODevice::ReadOnly | QIODevice::Text);
	QTextStream in(&resourceFile);

	// Skip the header line.
	in.readLine();

	const int numSymmetryOperations = 48;
	Tensor2 op(NULL_MATRIX);
	for(int n = 0; n < numSymmetryOperations; ++n) {
		// Nine values per operation, in row-major order.
		for(int entry = 0; entry < 9; ++entry)
			in >> op(entry / 3, entry % 3);

		// Keep proper rotations only; operations with a non-positive determinant are dropped.
		if(op.determinant() > 0)
			_pointGroupRotations.push_back(op);
	}
	OVITO_ASSERT(_pointGroupRotations.size() == 24);
}

/// An edge of the nearest-neighbor graph, connecting two atoms.
struct GraphEdge
{
	/// Indices of the two atoms joined by this edge.
	int a, b;

	/// Edge weight. For bulk edges, this holds the misorientation angle (radians)
	/// between the two atoms. Grain boundary edges never assign it.
	FloatType w;

	/// Orders edges by ascending weight (used to sort the bulk edge list).
	bool operator<(const GraphEdge& rhs) const {
		return rhs.w > w;
	}
};

/// Disjoint-set forest (union-find) over atom indices, using union by rank and
/// full path compression.
///
/// Besides the parent/rank bookkeeping, the root element of each cluster
/// stores the cluster's size and an orientation tensor. When two clusters are
/// joined, the root keeps the size-weighted average of the two cluster
/// orientations.
class DisjointSetForest
{
public:
	/// Creates 'count' singleton clusters. 'orientations' must point to at
	/// least 'count' consecutive tensors, which become the initial orientations.
	DisjointSetForest(int count, const Tensor2* orientations) : elements(count), numClusters(count) {
		for(QVector<Element>::iterator e = elements.begin(); e != elements.end(); ++e, ++orientations) {
			e->rank = 0;
			e->size = 1;
			e->p = e - elements.begin();
			e->orientation = *orientations;
		}
	}

	/// Returns the root element (cluster ID) of the given element.
	/// Every element on the path is re-linked directly to the root, so later
	/// queries for any of them take a single step.
	int getCluster(int atomIndex) {
		// First pass: find the root.
		int root = atomIndex;
		while(root != elements[root].p)
			root = elements[root].p;

		// Second pass: path compression.
		int current = atomIndex;
		while(elements[current].p != root) {
			int next = elements[current].p;
			elements[current].p = root;
			current = next;
		}
		return root;
	}

	/// Merges the clusters whose roots are x and y.
	/// Both arguments are expected to be distinct roots, as returned by getCluster().
	/// The root of the merged cluster receives the combined size and the
	/// size-weighted average of both orientation tensors.
	void joinClusters(int x, int y) {
		Element& ex = elements[x];
		Element& ey = elements[y];
		if(ex.rank > ey.rank) {
			ey.p = x;
			ex.size += ey.size;
			// t1 = weight of the absorbed cluster y in the merged cluster.
			FloatType t1 = (FloatType)ey.size / (FloatType)ex.size;
			FloatType t2 = 1.0 - t1;
			for(size_t i=0; i<3; i++)
				for(size_t j=0; j<3; j++)
					ex.orientation(i,j) = ex.orientation(i,j)*t2 + ey.orientation(i,j)*t1;
		} else {
			ex.p = y;
			ey.size += ex.size;
			// t1 = weight of the absorbed cluster x in the merged cluster.
			FloatType t1 = (FloatType)ex.size / (FloatType)ey.size;
			FloatType t2 = 1.0 - t1;
			for(size_t i=0; i<3; i++)
				for(size_t j=0; j<3; j++)
					ey.orientation(i,j) = ey.orientation(i,j)*t2 + ex.orientation(i,j)*t1;
			if(ex.rank == ey.rank)
				ey.rank++;
		}
		numClusters--;
	}

	/// Number of elements in the cluster ('cluster' must be a root).
	int clusterSize(int cluster) const { return elements[cluster].size; }
	/// Orientation tensor stored for the cluster ('cluster' must be a root).
	const Tensor2& clusterOrientation(int cluster) const { return elements[cluster].orientation; }
	/// Overwrites the orientation tensor stored for the cluster.
	void setClusterOrientation(int cluster, const Tensor2& q) { elements[cluster].orientation = q; }
	/// Current number of disjoint clusters.
	int numberOfClusters() const { return numClusters; }

private:

	struct Element {
		int rank;             ///< Union-by-rank upper bound on tree height (meaningful for roots).
		int p;                ///< Parent index; equals the element's own index for roots.
		int size;             ///< Cluster size (meaningful for roots).
		Tensor2 orientation;  ///< Cluster orientation (meaningful for roots).
	};

	QVector<Element> elements;
	int numClusters;
};

/******************************************************************************
* Performs the grain cluster analysis.
******************************************************************************/
bool FindGrains::performAnalysis(AtomsObject* input, DataChannel* outputClusterChannel, FloatType nearestNeighborCutoff,
		FloatType misorientationThreshold, int minCrystallineAtoms, DataChannel* misorientationChannel, bool suppressDialogs)
{
	CHECK_POINTER(input);
	CHECK_POINTER(outputClusterChannel);
	OVITO_ASSERT(minCrystallineAtoms >= 0);

	// Reset everything
	_grains.clear();

	// Show progress bar.
	ProgressIndicator progress(AnalyzeMicrostructureModifier::tr("Finding grains."), input->atomsCount(), suppressDialogs);

	// Check input data channels.
	DataChannel* deformationGradientChannel = input->getStandardDataChannel(DataChannel::DeformationGradientChannel);
	if(deformationGradientChannel == NULL)
		throw Exception(AnalyzeMicrostructureModifier::tr("Precalculated atomic crystal deformation gradients are required to divide the structure into grains. Please apply the Calculate Intrinsic Strain modifier first."));

	// Prepare output data channel.
	outputClusterChannel->setSize(input->atomsCount());
	fill(outputClusterChannel->dataInt(), outputClusterChannel->dataInt() + outputClusterChannel->size(), -1);
	if(misorientationChannel) {
		misorientationChannel->setSize(input->atomsCount());
		fill(misorientationChannel->dataFloat(), misorientationChannel->dataFloat() + misorientationChannel->size(), 0.0);
	}

	// Update progress bar.
	progress.setLabelText(AnalyzeMicrostructureModifier::tr("Building nearest-neighbor graph."));
	progress.isCanceled();

	// Prepare the neighbor list.
	OnTheFlyNeighborList neighborList(input, nearestNeighborCutoff);

	// Put together a list of weighted edges in the nearest neighbor graph.
	QVector<GraphEdge> bulkEdges;
	QVector<GraphEdge> grainBoundaryEdges;
	for(int i = 0; i < input->atomsCount(); i++) {

		// Update progress bar.
		if((i % 4096) == 0) {
			progress.setValue(i);
			if(progress.isCanceled()) {
				// Throw away results obtained so far if the user cancels the calculation.
				outputClusterChannel->setSize(0);
				return false;
			}
		}

		const Tensor2& F1 = deformationGradientChannel->getTensor2(i);
		FloatType det1 = F1.determinant();
		Tensor2 F1inverse;
		if(det1) F1inverse = F1.inverse();

		for(OnTheFlyNeighborList::iterator neighborIter(neighborList, i); !neighborIter.atEnd(); neighborIter.next()) {
			int j = neighborIter.current();
			if(j < i) continue;

			Tensor2 F2 = deformationGradientChannel->getTensor2(j);
			if(det1 != 0.0 && F2.determinant() != 0.0) {
				GraphEdge edge;
				edge.a = i;
				edge.b = j;
				edge.w = calculateMisorientation(F1inverse, F2);
				bulkEdges.push_back(edge);

				if(misorientationChannel) {
					FloatType angleInDegrees = edge.w * 180.0 / FLOATTYPE_PI;
					if(angleInDegrees > misorientationChannel->getFloat(i))
						misorientationChannel->setFloat(i, angleInDegrees);
					if(angleInDegrees > misorientationChannel->getFloat(j))
						misorientationChannel->setFloat(j, angleInDegrees);
				}
			}
			else {
				GraphEdge edge;
				edge.a = i;
				edge.b = j;
				grainBoundaryEdges.push_back(edge);
			}
		}
	}
	MsgLogger() << "Number of bulk graph edges:" << bulkEdges.size() << endl;
	MsgLogger() << "Number of grain boundary graph edges:" << grainBoundaryEdges.size() << endl;

	// Sort edges by weight.
	progress.setValue(0);
	progress.setLabelText(AnalyzeMicrostructureModifier::tr("Sorting graph edges."));
	progress.isCanceled();
	qSort(bulkEdges.begin(), bulkEdges.end());

	// Make a disjoint-set forest.
	DisjointSetForest u(input->atomsCount(), deformationGradientChannel->constDataTensor2());

	for(int iteration = 1; ; iteration++) {
		// Join components.
		progress.setValue(0);
		progress.setMaximum(bulkEdges.size());
		progress.setLabelText(AnalyzeMicrostructureModifier::tr("Performing segmentation (%1. pass).").arg(iteration));
		int oldNumberOfClusters = u.numberOfClusters();
		for(int i = 0; i < bulkEdges.size(); i++) {
			GraphEdge& edge = bulkEdges[i];

			// Update progress bar.
			if((i % 4096) == 0) {
				progress.setValue(i);
				if(progress.isCanceled()) {
					// Throw away results obtained so far if the user cancels the calculation.
					outputClusterChannel->setSize(0);
					return false;
				}
			}

			// Retreive components connected by this edge.
			int clusterA = u.getCluster(edge.a);
			int clusterB = u.getCluster(edge.b);

			// Skip edge if atoms already belong to the same cluster.
			if(clusterA == clusterB) continue;

			Tensor2 orientationB = u.clusterOrientation(clusterB);
			FloatType misorientation = calculateMisorientation(u.clusterOrientation(clusterA).inverse(), orientationB);
			OVITO_ASSERT(misorientation >= 0.0);

			// Join clusters if their misorientation is below the threshold.
			if(misorientation < misorientationThreshold) {
				u.setClusterOrientation(clusterB, orientationB);
				u.joinClusters(clusterA, clusterB);
			}
			else {
				// If the misorientation is above the threshold then join them as well if
				// one of them is too small to be a full grain.
				if(u.clusterSize(clusterA) < minCrystallineAtoms || u.clusterSize(clusterB) < minCrystallineAtoms) {
					u.setClusterOrientation(clusterB, orientationB);
					u.joinClusters(clusterA, clusterB);
				}
				else {
					if(misorientationChannel) {
						FloatType angleInDegrees = misorientation * 180.0 / FLOATTYPE_PI;
						if(angleInDegrees > misorientationChannel->getFloat(edge.a))
							misorientationChannel->setFloat(edge.a, angleInDegrees);
						if(angleInDegrees > misorientationChannel->getFloat(edge.b))
							misorientationChannel->setFloat(edge.b, angleInDegrees);
					}
				}
			}
		}
		// We are done when no join operation took place during the last iteration.
		if(u.numberOfClusters() == oldNumberOfClusters) break;
	}

	// Derive grains from clusters being large enough
	// and assign grain IDs to atoms.
	QMap<int, int> cluster2GrainMap;
	QBitArray assignedAtoms(input->atomsCount());
	for(int i = 0; i < input->atomsCount(); i++) {
		int clusterID = u.getCluster(i);
		QMap<int, int>::const_iterator iter = cluster2GrainMap.find(clusterID);
		if(iter == cluster2GrainMap.end()) {
			if(u.clusterSize(clusterID) >= minCrystallineAtoms) {
				int grainID = _grains.size();
				GrainInfo grain(grainID);
				grain._atomCount = grain._crystallineCount = u.clusterSize(clusterID);
				grain._averageOrientation = u.clusterOrientation(clusterID);
				_grains.push_back(grain);
				cluster2GrainMap[clusterID] = grainID;
				outputClusterChannel->setInt(i, grainID);
				assignedAtoms.setBit(i);
			}
		}
		else {
			int grainID = iter.value();
			outputClusterChannel->setInt(i, grainID);
			assignedAtoms.setBit(i);
		}
	}

	// Add remaining atoms from the grain boundaries to the grains.
	for(int iteration = 1; ; iteration++) {
		progress.setValue(0);
		progress.setMaximum(grainBoundaryEdges.size());
		progress.setLabelText(AnalyzeMicrostructureModifier::tr("Adding grain boundary atoms (%1. pass).").arg(iteration));
		bool finished = true;
		QBitArray newlyAssignedAtoms(input->atomsCount());
		for(int i = 0; i < grainBoundaryEdges.size(); i++) {
			GraphEdge& edge = grainBoundaryEdges[i];

			// Update progress bar.
			if((i % 4096) == 0) {
				progress.setValue(i);
				if(progress.isCanceled()) {
					// Throw away results obtained so far if the user cancels the calculation.
					outputClusterChannel->setSize(0);
					return false;
				}
			}

			if(assignedAtoms.at(edge.a) && !assignedAtoms.at(edge.b) && !newlyAssignedAtoms.at(edge.b)) {
				newlyAssignedAtoms.setBit(edge.b);
				int grainID = outputClusterChannel->getInt(edge.a);
				_grains[grainID]._atomCount++;
				outputClusterChannel->setInt(edge.b, grainID);
				finished = false;
			}
			else if(assignedAtoms.at(edge.b) && !assignedAtoms.at(edge.a) && !newlyAssignedAtoms.at(edge.a)) {
				newlyAssignedAtoms.setBit(edge.a);
				int grainID = outputClusterChannel->getInt(edge.b);
				_grains[grainID]._atomCount++;
				outputClusterChannel->setInt(edge.a, grainID);
				finished = false;
			}
		}
		if(finished) break;
		assignedAtoms |= newlyAssignedAtoms;
	}

	progress.setValue(0);
	progress.setMaximum(grainBoundaryEdges.size() + bulkEdges.size());
	progress.setLabelText(AnalyzeMicrostructureModifier::tr("Listing grain boundaries."));
	QMap< QPair<int, int>, int> gbMap;
	for(int i = 0; i < grainBoundaryEdges.size(); i++) {
		GraphEdge& edge = grainBoundaryEdges[i];
		// Update progress bar.
		if((i % 4096) == 0) {
			progress.setValue(i);
			if(progress.isCanceled()) {
				// Throw away results obtained so far if the user cancels the calculation.
				outputClusterChannel->setSize(0);
				return false;
			}
		}
		int grainA = outputClusterChannel->getInt(edge.a);
		int grainB = outputClusterChannel->getInt(edge.b);
		if(grainA != grainB && grainA >= 0 && grainB >= 0) {
			if(grainA > grainB) swap(grainA, grainB);
			QMap< QPair<int, int>, int>::const_iterator iter = gbMap.find(qMakePair(grainA, grainB));
			if(iter != gbMap.constEnd()) {
				_grainBoundaries[iter.value()]._numBonds++;
			}
			else {
				GrainBoundaryInfo gbInfo(_grainBoundaries.size());
				gbInfo._numBonds = 1;
				gbMap[qMakePair(grainA, grainB)] = _grainBoundaries.size();
				_grainBoundaries.push_back(gbInfo);
			}
		}
	}
	for(int i = 0; i < bulkEdges.size(); i++) {
		GraphEdge& edge = bulkEdges[i];
		// Update progress bar.
		if((i % 4096) == 0) {
			progress.setValue(i + grainBoundaryEdges.size());
			if(progress.isCanceled()) {
				// Throw away results obtained so far if the user cancels the calculation.
				outputClusterChannel->setSize(0);
				return false;
			}
		}
		int grainA = outputClusterChannel->getInt(edge.a);
		int grainB = outputClusterChannel->getInt(edge.b);
		if(grainA != grainB && grainA >= 0 && grainB >= 0) {
			if(grainA > grainB) swap(grainA, grainB);
			QMap< QPair<int, int>, int>::const_iterator iter = gbMap.find(qMakePair(grainA, grainB));
			if(iter != gbMap.constEnd()) {
				_grainBoundaries[iter.value()]._numBonds++;
			}
			else {
				GrainBoundaryInfo gbInfo(_grainBoundaries.size());
				gbInfo._numBonds = 1;
				gbMap[qMakePair(grainA, grainB)] = _grainBoundaries.size();
				_grainBoundaries.push_back(gbInfo);
			}
		}
	}


	return true;
}

/******************************************************************************
* Calculates the misorientation angle (radians) between two crystal
* orientations, taking the cubic lattice symmetry into account.
*
* F1inverse: inverse of the first orientation tensor.
* F2:        second orientation tensor (in/out). When the symmetry search below
*            runs, F2 is replaced by the symmetry-equivalent orientation
*            F2 * R that gives the smallest angle. Otherwise it is left
*            unchanged.
*
* NOTE(review): assumes Rotation(...).angle lies in [0, 2*pi). Angles above pi
* are folded back into [0, pi].
******************************************************************************/
FloatType FindGrains::calculateMisorientation(const Tensor2& F1inverse, Tensor2& F2)
{
	// Relative rotation between the two orientations, without symmetry.
	const Tensor2 relativeRotation = F1inverse * F2;

	FloatType directAngle = Rotation(relativeRotation).angle;
	if(directAngle > FLOATTYPE_PI)
		directAngle = 2.0 * FLOATTYPE_PI - directAngle;
	OVITO_ASSERT(directAngle >= 0.0);

	// For cubic symmetry, a rotation by less than 45 degrees already lies in the
	// fundamental zone. No symmetry-equivalent rotation can then be smaller.
	if(directAngle < 45.0 * FLOATTYPE_PI / 180.0)
		return directAngle;

	// Try every rotation of the point group and keep the first one that
	// yields the smallest angle.
	const Tensor2 unmodifiedF2 = F2;
	FloatType bestAngle = FLOATTYPE_PI;
	for(int k = 0; k < _pointGroupRotations.size(); k++) {
		const Tensor2& symmetryOp = _pointGroupRotations.at(k);
		FloatType candidate = Rotation(relativeRotation * symmetryOp).angle;
		if(candidate > FLOATTYPE_PI)
			candidate = 2.0 * FLOATTYPE_PI - candidate;
		if(candidate < bestAngle) {
			bestAngle = candidate;
			F2 = unmodifiedF2 * symmetryOp;
		}
	}

	// The maximum cubic disorientation angle is about 62.8 degrees.
	OVITO_ASSERT(bestAngle < 63.0 * FLOATTYPE_PI / 180.0);
	return bestAngle;
}

/******************************************************************************
* Saves the class' contents to the given stream.
*
* Layout (mirrored by loadFromStream()):
*   chunk 0x10000000: grain count, then one chunk per grain containing
*                     _id, _atomCount, _crystallineCount, _grainColor,
*                     _averageOrientation;
*   chunk 0x10000000: grain boundary count, then one chunk per boundary
*                     containing _id, _grains[0], _grains[1].
*
* NOTE(review): GrainBoundaryInfo::_numBonds (computed in performAnalysis())
* is not written, so it is not restored after loading. Changing this would
* alter the file format.
******************************************************************************/
void FindGrains::saveToStream(ObjectSaveStream& stream)
{
	// Grain table.
	stream.beginChunk(0x10000000);
	stream.writeSizeT(_grains.size());
	Q_FOREACH(const GrainInfo& grain, grains()) {
		stream.beginChunk(0x10000000);
		stream << grain._id;
		stream << grain._atomCount;
		stream << grain._crystallineCount;
		stream << grain._grainColor;
		stream << grain._averageOrientation;
		stream.endChunk();
	}
	stream.endChunk();

	// Grain boundary table.
	stream.beginChunk(0x10000000);
	stream.writeSizeT(_grainBoundaries.size());
	Q_FOREACH(const GrainBoundaryInfo& gb, grainBoundaries()) {
		stream.beginChunk(0x10000000);
		stream << gb._id;
		stream << gb._grains[0];
		stream << gb._grains[1];
		stream.endChunk();
	}
	stream.endChunk();
}

/******************************************************************************
* Loads the class' contents from the given stream.
*
* Reads the two chunks written by saveToStream(): first the grain table
* (a count followed by one chunk per grain), then the grain boundary table
* (a count followed by one chunk per boundary). Existing contents are
* resized to the stored counts and overwritten.
******************************************************************************/
void FindGrains::loadFromStream(ObjectLoadStream& stream)
{
	// Grain table.
	stream.expectChunk(0x10000000);
	size_t grainCount;
	stream.readSizeT(grainCount);
	_grains.resize(grainCount);
	for(int index = 0; index < _grains.size(); ++index) {
		GrainInfo& g = _grains[index];
		stream.expectChunk(0x10000000);
		stream >> g._id;
		stream >> g._atomCount;
		stream >> g._crystallineCount;
		stream >> g._grainColor;
		stream >> g._averageOrientation;
		stream.closeChunk();
	}
	stream.closeChunk();

	// Grain boundary table.
	stream.expectChunk(0x10000000);
	size_t boundaryCount;
	stream.readSizeT(boundaryCount);
	_grainBoundaries.resize(boundaryCount);
	for(int index = 0; index < _grainBoundaries.size(); ++index) {
		GrainBoundaryInfo& b = _grainBoundaries[index];
		stream.expectChunk(0x10000000);
		stream >> b._id;
		stream >> b._grains[0];
		stream >> b._grains[1];
		stream.closeChunk();
	}
	stream.closeChunk();
}

};	// End of namespace CrystalAnalysis
