///////////////////////////////////////////////////////////////////////////////
//
//  Copyright (2008) Alexander Stukowski
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2 of the License, or
//  (at your option) any later version.
//
//  OVITO is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
///////////////////////////////////////////////////////////////////////////////

#include <core/Core.h>
#include <core/undo/UndoManager.h>
#include <atomviz/utils/muparser/muParser.h>

#include "CreateExpressionChannelModifier.h"

namespace AtomViz {

// Register the modifier class with the plugin/serialization system (second argument presumably names the base class).
IMPLEMENT_SERIALIZABLE_PLUGIN_CLASS(CreateExpressionChannelModifier, AtomsObjectModifierBase)
// Property fields of the modifier. The string argument is presumably the identifier under which each
// field is serialized — TODO confirm. None of these fields stores the number of components:
// setDataChannelComponentCount() adjusts the length of the expression list instead.
DEFINE_PROPERTY_FIELD(CreateExpressionChannelModifier, "Expressions", _expressions)
DEFINE_PROPERTY_FIELD(CreateExpressionChannelModifier, "DataChannelId", _dataChannelId)
DEFINE_PROPERTY_FIELD(CreateExpressionChannelModifier, "DataChannelName", _dataChannelName)
DEFINE_PROPERTY_FIELD(CreateExpressionChannelModifier, "DataChannelDataType", _dataChannelDataType)
DEFINE_PROPERTY_FIELD(CreateExpressionChannelModifier, "DataChannelVisibility", _dataChannelVisibility)
DEFINE_PROPERTY_FIELD(CreateExpressionChannelModifier, "OnlySelectedAtoms", _onlySelectedAtoms)
// Human-readable labels of the property fields (presumably shown by property UI widgets such as the
// BooleanPropertyUI check boxes created in the editor).
SET_PROPERTY_FIELD_LABEL(CreateExpressionChannelModifier, _expressions, "Expressions")
SET_PROPERTY_FIELD_LABEL(CreateExpressionChannelModifier, _dataChannelId, "Channel identifier")
SET_PROPERTY_FIELD_LABEL(CreateExpressionChannelModifier, _dataChannelName, "Channel name")
SET_PROPERTY_FIELD_LABEL(CreateExpressionChannelModifier, _dataChannelDataType, "Data type")
SET_PROPERTY_FIELD_LABEL(CreateExpressionChannelModifier, _dataChannelVisibility, "Show channel")
SET_PROPERTY_FIELD_LABEL(CreateExpressionChannelModifier, _onlySelectedAtoms, "Calculate values only for selected atoms")

/******************************************************************************
* Changes the identifier of the data channel created by this modifier.
* Selecting a standard channel also adopts that channel's name, data type and
* number of components; selecting the user-defined channel keeps the current
* name, type and component count.
******************************************************************************/
void CreateExpressionChannelModifier::setDataChannelId(DataChannel::DataChannelIdentifier newId)
{
	// Nothing to do if the identifier does not change.
	if(dataChannelId() == newId)
		return;

	_dataChannelId = newId;

	// A user-defined channel keeps whatever name, type and component count it had.
	if(newId == DataChannel::UserDataChannel)
		return;

	// A standard channel dictates its own name, data type and component count.
	setDataChannelName(DataChannel::standardChannelName(newId));
	setDataChannelDataType(DataChannel::standardChannelType(newId));
	setDataChannelComponentCount(DataChannel::standardChannelComponentCount(newId));
}

/******************************************************************************
* Sets the number of vector components of the data channel to create.
******************************************************************************/
void CreateExpressionChannelModifier::setDataChannelComponentCount(int newComponentCount)
{
	if(newComponentCount == this->dataChannelComponentCount()) return;

	if(newComponentCount < dataChannelComponentCount()) {
		setExpressions(expressions().mid(0, newComponentCount));
	}
	else {
		QStringList newList = expressions();
		while(newList.size() < newComponentCount)
			newList.append("0");
		setExpressions(newList);
	}
}

/******************************************************************************
* Builds the list of variable names the math expressions may refer to.
* Every component of an int or FloatType input channel yields one variable:
* a channel without component names contributes its (sanitized) channel name,
* other channels contribute "Channel.Component" for each component. All
* characters other than letters, digits and '_' are stripped from channel and
* component names. "AtomIndex" is appended if the input has no atom index
* channel.
******************************************************************************/
QStringList CreateExpressionChannelModifier::getVariableNames()
{
	// Matches every character that must not appear in a variable name.
	const QRegExp invalidChars("[^A-Za-z\\d_]");

	QStringList names;
	Q_FOREACH(DataChannel* channel, input()->dataChannels()) {

		// Channels of custom data type are not supported by this modifier.
		const bool isSupported = (channel->type() == qMetaTypeId<int>()) || (channel->type() == qMetaTypeId<FloatType>());
		if(!isSupported)
			continue;

		QString baseName = channel->name();
		baseName.remove(invalidChars);

		if(channel->componentNames().empty()) {
			// No component names: the channel holds a single value per atom.
			OVITO_ASSERT(channel->componentCount() == 1);
			names << baseName;
			continue;
		}

		Q_FOREACH(QString component, channel->componentNames()) {
			component.remove(invalidChars);
			names << (baseName + "." + component);
		}
	}

	if(input()->getStandardDataChannel(DataChannel::AtomIndexChannel) == NULL)
		names << "AtomIndex";

	return names;
}

/**
 * This helper class is needed to enable multi-threaded evaluation of math expressions
 * for all atoms. Each instance of this class is assigned a chunk of atoms that it processes.
 */
class CreateExpressionEvaluationKernel
{
private:
	struct ExpressionVariable {
		double value;
		const char* dataPointer;
		size_t stride;
		bool isFloat;
	};

public:
	/// Initializes the expressions parsers.
	bool initialize(const QStringList& expressions, const QStringList& variableNames, AtomsObject* input, int timestep) {
		parsers.resize(expressions.size());
		variables.resize(variableNames.size());
		bool usesTimeInExpression = false;

		// Compile the expression strings.
		for(int i=0; i<expressions.size(); i++) {

			QString expr = expressions[i];
			if(expr.isEmpty())
				throw Exception(CreateExpressionChannelModifier::tr("The expression for the %1. component is empty.").arg(i+1));

			try {
				// Configure parser to accept '.' in variable names.
				parsers[i].DefineNameChars("0123456789_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.");

				// Add the atan2() function.
				parsers[i].DefineFun("atan2", atan2, false);

				// Let the muParser process the math expression.
				parsers[i].SetExpr(expr.toStdString());

				// Register variables
				for(int v=0; v<variableNames.size(); v++)
					parsers[i].DefineVar(variableNames[v].toStdString(), &variables[v].value);

				// If the current animation time is used in the math expression then we have to
				// reduce the validity interval to the current time only.
				mu::varmap_type usedVariables = parsers[i].GetUsedVar();
				if(usedVariables.find("t") != usedVariables.end())
					usesTimeInExpression = true;

				// Add constants.
				parsers[i].DefineConst("pi", 3.1415926535897932);
				parsers[i].DefineConst("N", input->atomsCount());
				parsers[i].DefineConst("t", timestep);
			}
			catch(mu::Parser::exception_type& ex) {
				throw Exception(QString("%1").arg(QString::fromStdString(ex.GetMsg())));
			}
		}

		// Setup data pointers to input channels.
		size_t vindex = 0;
		Q_FOREACH(DataChannel* channel, input->dataChannels()) {
			if(channel->type() == qMetaTypeId<FloatType>()) {
				for(size_t k=0; k<channel->componentCount(); k++) {
					OVITO_ASSERT(vindex < variableNames.size());
					variables[vindex].dataPointer = reinterpret_cast<const char*>(channel->constDataFloat() + k);
					variables[vindex].stride = channel->perAtomSize();
					variables[vindex].isFloat = true;
					vindex++;
				}
			}
			else if(channel->type() == qMetaTypeId<int>()) {
				for(size_t k=0; k<channel->componentCount(); k++) {
					OVITO_ASSERT(vindex < variableNames.size());
					variables[vindex].dataPointer = reinterpret_cast<const char*>(channel->constDataInt() + k);
					variables[vindex].stride = channel->perAtomSize();
					variables[vindex].isFloat = false;
					vindex++;
				}
			}
			else OVITO_ASSERT(false);
		}

		// Add the special AtomIndex variable if there is no dedicated data channel.
		if(input->getStandardDataChannel(DataChannel::AtomIndexChannel) == NULL) {
			variables[vindex].dataPointer = NULL;
			variables[vindex].stride = 0;
			variables[vindex].isFloat = false;
			vindex++;
		}

		OVITO_ASSERT(vindex == variableNames.size());

		return usesTimeInExpression;
	}

	void run(int startIndex, int endIndex, DataChannel* outputChannel, const int* selectionValues) {
		try {
			// Position pointers.
			if(selectionValues) selectionValues += startIndex;
			for(vector<ExpressionVariable>::iterator v = variables.begin(); v != variables.end(); ++v)
				v->dataPointer += v->stride * startIndex;

			for(int i = startIndex; i < endIndex; i++) {

				// Update variable values for the current atom.
				for(vector<ExpressionVariable>::iterator v = variables.begin(); v != variables.end(); ++v) {
					if(v->isFloat)
						v->value = *reinterpret_cast<const FloatType*>(v->dataPointer);
					else if(v->dataPointer)
						v->value = *reinterpret_cast<const int*>(v->dataPointer);
					else
						v->value = i;
					v->dataPointer += v->stride;
				}

				// Skip unselected atoms if restricted to selected atoms.
				if(selectionValues) {
					if(!(*selectionValues++))
						continue;
				}

				for(int j = 0; j < parsers.size(); j++) {
					// Evaluate expression for the current atom.
					double value = parsers[j].Eval();

					// Store computed value in output channel.
					if(outputChannel->type() == qMetaTypeId<int>())
						outputChannel->setIntComponent(i, j, (int)value);
					else
						outputChannel->setFloatComponent(i, j, (FloatType)value);
				}
			}
		}
		catch(const mu::Parser::exception_type& ex) {
			errorMsg = QString::fromStdString(ex.GetMsg());
		}
	}

	QString errorMsg;
private:
	QVector<mu::Parser> parsers;
	vector<ExpressionVariable> variables;
};

/******************************************************************************
* This modifies the input object.
*
* Evaluates the math expressions for all atoms in parallel, using one
* evaluation kernel per ideal thread, and stores the results in the output
* channel. If an expression uses the time constant 't', the validity interval
* is reduced to the current time.
*
* Throws an Exception if an expression is empty or invalid, if the output
* channel has an invalid data type or too few components, if evaluation is
* restricted to selected atoms but no selection channel exists, or if the
* evaluation of an expression fails.
******************************************************************************/
EvaluationStatus CreateExpressionChannelModifier::modifyAtomsObject(TimeTicks time, TimeInterval& validityInterval)
{
	// Get list of available input variables.
	_variableNames = getVariableNames();

	// Create and initialize the worker kernels. Each kernel compiles its own copy of the expressions.
	const int animationFrame = ANIM_MANAGER.timeToFrame(time);
	QVector<CreateExpressionEvaluationKernel> workers(max(QThread::idealThreadCount(), 1));
	for(QVector<CreateExpressionEvaluationKernel>::iterator worker = workers.begin(); worker != workers.end(); ++worker) {
		if(worker->initialize(expressions(), _variableNames, input(), animationFrame))
			validityInterval.intersect(TimeInterval(time));
	}

	// Prepare the deep copy of the output channel.
	DataChannel* outputChannel;
	if(dataChannelId() != DataChannel::UserDataChannel)
		outputChannel = outputStandardChannel(dataChannelId());
	else {
		size_t dataTypeSize;
		if(dataChannelDataType() == qMetaTypeId<int>())
			dataTypeSize = sizeof(int);
		else if(dataChannelDataType() == qMetaTypeId<FloatType>())
			dataTypeSize = sizeof(FloatType);
		else
			throw Exception(tr("New data channel has an invalid data type."));
		outputChannel = output()->createCustomDataChannel(dataChannelDataType(), dataTypeSize, dataChannelComponentCount());
		outputChannel->setName(dataChannelName());
	}
	CHECK_POINTER(outputChannel);
	outputChannel->setVisible(dataChannelVisibility());

	// The workers write expression j into component j of every atom. For standard channels the
	// component count comes from the channel definition, not from the expression list, so make
	// sure the workers never address components that do not exist.
	if((size_t)expressions().size() > (size_t)outputChannel->componentCount())
		throw Exception(tr("The number of expressions (%1) exceeds the number of components of the output data channel (%2).").arg(expressions().size()).arg((int)outputChannel->componentCount()));

	// Get the selection channel if the application of the modifier is restricted to selected atoms.
	const int* selectionValues = NULL;
	if(_onlySelectedAtoms) {
		DataChannel* selChannel = input()->getStandardDataChannel(DataChannel::SelectionChannel);
		if(!selChannel) {
			throw Exception(tr("Evaluation has been restricted to selected atoms but input object does not contain a selection channel."));
		}
		selectionValues = selChannel->constDataInt();
	}

	// This call is necessary to deep copy the memory array of the output channel before accessing it from multiple threads.
	outputChannel->data();

	// Spawn worker threads. Each worker gets a contiguous chunk of atoms; the last one also
	// takes the remainder. Workers whose chunk is empty are not started.
	QFutureSynchronizer<void> synchronizer;
	int chunkSize = max((int)input()->atomsCount() / workers.size(), 1);
	for(int i = 0; i < workers.size(); i++) {
		// Setup data range.
		int startIndex = i * chunkSize;
		int endIndex = min((i+1) * chunkSize, (int)input()->atomsCount());
		if(i == workers.size() - 1) endIndex = input()->atomsCount();
		if(endIndex <= startIndex) continue;

		synchronizer.addFuture(QtConcurrent::run(&workers[i], &CreateExpressionEvaluationKernel::run, startIndex, endIndex, outputChannel, selectionValues));
	}
	synchronizer.waitForFinished();

	// Check for errors. Workers report parser errors via errorMsg instead of throwing.
	for(int i = 0; i < workers.size(); i++) {
		if(workers[i].errorMsg.isEmpty() == false)
			throw Exception(workers[i].errorMsg);
	}

	return EvaluationStatus();
}

// Register the editor class (base class: AtomsObjectModifierEditorBase) with the plugin system.
IMPLEMENT_PLUGIN_CLASS(CreateExpressionChannelModifierEditor, AtomsObjectModifierEditorBase)

/******************************************************************************
* Sets up the UI widgets of the editor.
******************************************************************************/
void CreateExpressionChannelModifierEditor::createUI(const RolloutInsertionParameters& rolloutParams)
{
	QWidget* rollout = createRollout(tr("Create Expression Channel"), rolloutParams, "atomviz.modifiers.add_expression_channel");

    // Create the rollout contents.
	QVBoxLayout* mainLayout = new QVBoxLayout(rollout);
	mainLayout->setContentsMargins(4,4,4,4);

	QGroupBox* channelPropertiesGroupBox = new QGroupBox(tr("Data channel properties"));
	mainLayout->addWidget(channelPropertiesGroupBox);
	QGridLayout* channelPropertiesLayout = new QGridLayout(channelPropertiesGroupBox);
	channelPropertiesLayout->setContentsMargins(4,4,4,4);
	channelPropertiesLayout->setColumnStretch(1, 1);
	channelPropertiesLayout->setSpacing(2);

	// Create the combo box with the standard data channel identifiers.
	VariantComboBoxPropertyUI* dataChannelIdUI = new VariantComboBoxPropertyUI(this, "dataChannelId");
	channelPropertiesLayout->addWidget(new QLabel(tr("Data channel to create:")), 0, 0);
	channelPropertiesLayout->addWidget(dataChannelIdUI->comboBox(), 0, 1, 1, 2);
	QMap<QString, DataChannel::DataChannelIdentifier> standardChannels = DataChannel::standardChannelList();
	dataChannelIdUI->comboBox()->addItem(tr("Custom"), DataChannel::UserDataChannel);
	for(QMap<QString, DataChannel::DataChannelIdentifier>::const_iterator i = standardChannels.constBegin(); i != standardChannels.constEnd(); ++i) {
		dataChannelIdUI->comboBox()->addItem(i.key(), i.value());
	}

	// Create the field with the data channel name.
	dataChannelNameUI = new StringPropertyUI(this, "dataChannelName");
	channelPropertiesLayout->addWidget(new QLabel(tr("Name:")), 1, 0);
	channelPropertiesLayout->addWidget(dataChannelNameUI->textBox(), 1, 1, 1, 2);

	// Create the combo box with the data channel types.
	dataChannelTypeUI = new VariantComboBoxPropertyUI(this, "dataChannelDataType");
	channelPropertiesLayout->addWidget(new QLabel(tr("Data type:")), 2, 0);
	channelPropertiesLayout->addWidget(dataChannelTypeUI->comboBox(), 2, 1, 1, 2);
	dataChannelTypeUI->comboBox()->addItem(tr("Float"), qMetaTypeId<FloatType>());
	dataChannelTypeUI->comboBox()->addItem(tr("Integer"), qMetaTypeId<int>());

	// Create the spinner for the number of components.
	numComponentsUI = new IntegerPropertyUI(this, "dataChannelComponentCount");
	numComponentsUI->setMinValue(1);
	numComponentsUI->setMaxValue(64);
	channelPropertiesLayout->addWidget(new QLabel(tr("Number of components:")), 3, 0);
	channelPropertiesLayout->addWidget(numComponentsUI->textBox(), 3, 1);
	channelPropertiesLayout->addWidget(numComponentsUI->spinner(), 3, 2);

	// Create the check box for the channel visibility flag.
	BooleanPropertyUI* visibilityFlagUI = new BooleanPropertyUI(this, PROPERTY_FIELD_DESCRIPTOR(CreateExpressionChannelModifier, _dataChannelVisibility));
	channelPropertiesLayout->addWidget(visibilityFlagUI->checkBox(), 4, 0, 1, 3);

	// Create the check box for the selection flag.
	BooleanPropertyUI* selectionFlagUI = new BooleanPropertyUI(this, PROPERTY_FIELD_DESCRIPTOR(CreateExpressionChannelModifier, _onlySelectedAtoms));
	channelPropertiesLayout->addWidget(selectionFlagUI->checkBox(), 5, 0, 1, 3);

	QGroupBox* expressionsGroupBox = new QGroupBox(tr("Expressions"));
	mainLayout->addWidget(expressionsGroupBox);
	expressionsLayout = new QVBoxLayout(expressionsGroupBox);
	expressionsLayout->setContentsMargins(4,4,4,4);
	expressionsLayout->setSpacing(0);

	// Status label.
	mainLayout->addWidget(statusLabel());

	QWidget* variablesRollout = createRollout(tr("Variables"), rolloutParams.after(rollout), "atomviz.modifiers.add_expression_channel");
    QVBoxLayout* variablesLayout = new QVBoxLayout(variablesRollout);
    variablesLayout->setContentsMargins(4,4,4,4);
	variableNamesList = new QLabel();
	variableNamesList->setWordWrap(true);
	variableNamesList->setTextInteractionFlags(Qt::TextSelectableByMouse | Qt::TextSelectableByKeyboard | Qt::LinksAccessibleByMouse | Qt::LinksAccessibleByKeyboard);
	variablesLayout->addWidget(variableNamesList);
}

/******************************************************************************
* This method is called when a reference target changes. Refreshes the editor
* fields whenever the edited modifier reports a change, then forwards the
* message to the base class.
******************************************************************************/
bool CreateExpressionChannelModifierEditor::onRefTargetMessage(RefTarget* source, RefTargetMessage* msg)
{
	const bool editedObjectChanged = (source == editObject()) && (msg->type() == REFTARGET_CHANGED);
	if(editedObjectChanged)
		updateEditorFields();

	// Let the base class handle the message as well.
	return AtomsObjectModifierEditorBase::onRefTargetMessage(source, msg);
}

/******************************************************************************
* Updates the enabled/disabled status of the editor's controls.
******************************************************************************/
void CreateExpressionChannelModifierEditor::updateEditorFields()
{
	CreateExpressionChannelModifier* mod = static_object_cast<CreateExpressionChannelModifier>(editObject());
	dataChannelNameUI->setEnabled(mod && mod->dataChannelId() == DataChannel::UserDataChannel);
	dataChannelTypeUI->setEnabled(mod && mod->dataChannelId() == DataChannel::UserDataChannel);
	numComponentsUI->setEnabled(mod && mod->dataChannelId() == DataChannel::UserDataChannel);
	if(!mod) return;

	const QStringList& expr = mod->expressions();
	while(expr.size() > expressionBoxes.size()) {
		QLabel* label = new QLabel(container());
		QLineEdit* edit = new QLineEdit(container());
		expressionsLayout->insertWidget(expressionBoxes.size()*2, label);
		expressionsLayout->insertWidget(expressionBoxes.size()*2 + 1, edit);
		expressionBoxes.push_back(edit);
		expressionBoxLabels.push_back(label);
		connect(edit, SIGNAL(editingFinished()), this, SLOT(onExpressionEditingFinished()));
	}
	while(expr.size() < expressionBoxes.size()) {
		delete expressionBoxes.takeLast();
		delete expressionBoxLabels.takeLast();
	}
	OVITO_ASSERT(expressionBoxes.size() == expr.size());
	OVITO_ASSERT(expressionBoxLabels.size() == expr.size());

	QStringList standardChannelComponentNames;
	if(mod->dataChannelId() != DataChannel::UserDataChannel) {
		standardChannelComponentNames = DataChannel::standardChannelComponentNames(mod->dataChannelId());
		if(standardChannelComponentNames.empty())
			standardChannelComponentNames.push_back(DataChannel::standardChannelName(mod->dataChannelId()));
	}
	for(int i=0; i<expr.size(); i++) {
		expressionBoxes[i]->setText(expr[i]);
		if(i < standardChannelComponentNames.size())
			expressionBoxLabels[i]->setText(tr("%1:").arg(standardChannelComponentNames[i]));
		else
			expressionBoxLabels[i]->setText(tr("Component %1:").arg(i+1));
	}

	QString labelText(tr("The following variables can be used in the math expressions:<ul>"));
	Q_FOREACH(QString s, mod->lastVariableNames()) {
		labelText.append(QString("<li>%1</li>").arg(s));
	}
	labelText.append(QString("<li>N (number of atoms)</li>"));
	labelText.append(QString("<li>t (current time frame)</li>"));
	labelText.append("</ul><p></p>");
	variableNamesList->setText(labelText);
}

/******************************************************************************
* Is called when the user has typed in an expression.
******************************************************************************/
void CreateExpressionChannelModifierEditor::onExpressionEditingFinished()
{
	QLineEdit* edit = (QLineEdit*)sender();
	int index = expressionBoxes.indexOf(edit);
	OVITO_ASSERT(index >= 0);

	CreateExpressionChannelModifier* mod = static_object_cast<CreateExpressionChannelModifier>(editObject());
	QStringList expr = mod->expressions();
	expr[index] = edit->text();

	UNDO_MANAGER.beginCompoundOperation(tr("Change Expression"));
	mod->setExpressions(expr);
	UNDO_MANAGER.endCompoundOperation();
}

};	// End of namespace AtomViz
