// Copyright (C) 2006-2009 Kent-Andre Mardal and Simula Research Laboratory.
// Licensed under the GNU GPL Version 2, or (at your option) any later version.

#include "Hermite.h"
#include "tools.h"

using std::cout;
using std::endl;
using std::string;

namespace SyFi
{

	// Default constructor. Only the generic element name is recorded; no
	// basis functions are built here (presumably no polygon is attached by
	// StandardFE() yet, so compute_basis_functions() would throw).
	Hermite::Hermite()
		: StandardFE()
	{
		this->description = "Hermite";
	}

	// Construct a Hermite element on polygon `p` and build its basis right away.
	// NOTE(review): `order` is only forwarded to StandardFE; the polynomial
	// degree used by compute_basis_functions() is hard-coded to 3.
	Hermite::Hermite(Polygon& p, int order)
		: StandardFE(p, order)
	{
		this->compute_basis_functions();
	}

	void Hermite:: compute_basis_functions()
	{

		// remove previously computed basis functions and dofs
		Ns.clear();
		dofs.clear();

		if ( p == NULL )
		{
			throw(std::logic_error("You need to set a polygon before the basisfunctions can be computed"));
		}

		GiNaC::ex polynom_space;
		GiNaC::ex polynom;
		GiNaC::lst variables;
		GiNaC::lst equations;

		if ( p->str().find("Line") != string::npos )
		{

			description = "Hermite_1D";

			polynom_space = legendre(3, 1, "a");
			polynom = polynom_space.op(0);
			variables = GiNaC::ex_to<GiNaC::lst>(polynom_space.op(1));

			for (int i=0; i< 2; i++)
			{
				GiNaC::ex v = p->vertex(i);
				GiNaC::ex dofv   = polynom.subs(GiNaC::lst(x == v.op(0)));
				GiNaC::ex dofvdx = diff(polynom,x).subs(GiNaC::lst(x == v.op(0)));
				equations.append( dofv   == GiNaC::numeric(0));
				equations.append( dofvdx == GiNaC::numeric(0));

				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), 0));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), 1));
			}

		}

		if ( p->str().find("Triangle") != string::npos )
		{

			description = "Hermite_2D";

			polynom_space = pol(3, 2, "a");
			polynom = polynom_space.op(0);
			variables = GiNaC::ex_to<GiNaC::lst>(polynom_space.op(1));

			for (int i=0; i<= 2; i++)
			{
				GiNaC::ex v = p->vertex(i);
				GiNaC::ex dofv = polynom.subs(GiNaC::lst(x == v.op(0), y == v.op(1)));
				GiNaC::ex dofvdx = diff(polynom,x).subs(GiNaC::lst(x == v.op(0), y == v.op(1)));
				GiNaC::ex dofvdy = diff(polynom,y).subs(GiNaC::lst(x == v.op(0), y == v.op(1)));

				equations.append( dofv   == GiNaC::numeric(0));
				equations.append( dofvdx == GiNaC::numeric(0));
				equations.append( dofvdy == GiNaC::numeric(0));

				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 0));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 1));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 2));
			}
			GiNaC::ex midpoint = GiNaC::lst((p->vertex(0).op(0) + p->vertex(1).op(0) + p->vertex(2).op(0))/3,
				(p->vertex(0).op(1) + p->vertex(1).op(1) + p->vertex(2).op(1))/3);
			GiNaC::ex dofm = polynom.subs(GiNaC::lst(x == midpoint.op(0), y == midpoint.op(1)));
			dofs.insert(dofs.end(), midpoint );
			equations.append( dofm == GiNaC::numeric(0));

		}

		else if ( p->str().find("Rectangle") != string::npos )
		{

			description = "Hermite_2D";

			polynom_space = legendre(3, 2, "a");
			polynom = polynom_space.op(0);
			variables = GiNaC::ex_to<GiNaC::lst>(polynom_space.op(1));

			for (int i=0; i< 4; i++)
			{
				GiNaC::ex v = p->vertex(i);
				GiNaC::ex dofv   = polynom.subs(GiNaC::lst(x == v.op(0), y == v.op(1)));
				GiNaC::ex dofvdx = diff(polynom,x).subs(GiNaC::lst(x == v.op(0), y == v.op(1)));
				GiNaC::ex dofvdy = diff(polynom,y).subs(GiNaC::lst(x == v.op(0), y == v.op(1)));
				GiNaC::ex dofvdyx = diff(diff(polynom,y),x).subs(GiNaC::lst(x == v.op(0), y == v.op(1)));
				equations.append( dofv   == GiNaC::numeric(0));
				equations.append( dofvdx == GiNaC::numeric(0));
				equations.append( dofvdy == GiNaC::numeric(0));
				equations.append( dofvdyx == GiNaC::numeric(0));

				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 0));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 1));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 2));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 3));
			}

		}
		else if ( p->str().find("Tetrahedron") != string::npos )
		{

			description = "Hermite_3D";

			polynom_space = pol(3, 3, "a");
			polynom = polynom_space.op(0);
			variables = GiNaC::ex_to<GiNaC::lst>(polynom_space.op(1));

			for (int i=0; i<= 3; i++)
			{
				GiNaC::ex v = p->vertex(i);
				GiNaC::ex dofv   = polynom.subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdx = diff(polynom,x).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdy = diff(polynom,y).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdz = diff(polynom,z).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));

				equations.append( dofv   == GiNaC::numeric(0));
				equations.append( dofvdx == GiNaC::numeric(0));
				equations.append( dofvdy == GiNaC::numeric(0));
				equations.append( dofvdz == GiNaC::numeric(0));

				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 0));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 1));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 2));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), 3));

			}
			GiNaC::ex midpoint1 = GiNaC::lst(
				(p->vertex(0).op(0)*2 + p->vertex(1).op(0) + p->vertex(2).op(0) + p->vertex(3).op(0))/5,
				(p->vertex(0).op(1)*2 + p->vertex(1).op(1) + p->vertex(2).op(1) + p->vertex(3).op(1))/5,
				(p->vertex(0).op(2)*2 + p->vertex(1).op(2) + p->vertex(2).op(2) + p->vertex(3).op(2))/5);

			GiNaC::ex midpoint2 = GiNaC::lst(
				(p->vertex(0).op(0) + p->vertex(1).op(0)*2 + p->vertex(2).op(0) + p->vertex(3).op(0))/5,
				(p->vertex(0).op(1) + p->vertex(1).op(1)*2 + p->vertex(2).op(1) + p->vertex(3).op(1))/5,
				(p->vertex(0).op(2) + p->vertex(1).op(2)*2 + p->vertex(2).op(2) + p->vertex(3).op(2))/5);

			GiNaC::ex midpoint3 = GiNaC::lst(
				(p->vertex(0).op(0) + p->vertex(1).op(0) + p->vertex(2).op(0)*2 + p->vertex(3).op(0))/5,
				(p->vertex(0).op(1) + p->vertex(1).op(1) + p->vertex(2).op(1)*2 + p->vertex(3).op(1))/5,
				(p->vertex(0).op(2) + p->vertex(1).op(2) + p->vertex(2).op(2)*2 + p->vertex(3).op(2))/5);

			GiNaC::ex midpoint4 = GiNaC::lst(
				(p->vertex(0).op(0) + p->vertex(1).op(0) + p->vertex(2).op(0) + p->vertex(3).op(0)*2)/5,
				(p->vertex(0).op(1) + p->vertex(1).op(1) + p->vertex(2).op(1) + p->vertex(3).op(1)*2)/5,
				(p->vertex(0).op(2) + p->vertex(1).op(2) + p->vertex(2).op(2) + p->vertex(3).op(2)*2)/5);

			GiNaC::ex dofm1 = polynom.subs(GiNaC::lst(x == midpoint1.op(0), y == midpoint1.op(1), z == midpoint1.op(2)));
			GiNaC::ex dofm2 = polynom.subs(GiNaC::lst(x == midpoint2.op(0), y == midpoint2.op(1), z == midpoint2.op(2)));
			GiNaC::ex dofm3 = polynom.subs(GiNaC::lst(x == midpoint3.op(0), y == midpoint3.op(1), z == midpoint3.op(2)));
			GiNaC::ex dofm4 = polynom.subs(GiNaC::lst(x == midpoint4.op(0), y == midpoint4.op(1), z == midpoint4.op(2)));

			dofs.insert(dofs.end(), midpoint1 );
			dofs.insert(dofs.end(), midpoint2 );
			dofs.insert(dofs.end(), midpoint3 );
			dofs.insert(dofs.end(), midpoint4 );

			equations.append( dofm1 == GiNaC::numeric(0));
			equations.append( dofm2 == GiNaC::numeric(0));
			equations.append( dofm3 == GiNaC::numeric(0));
			equations.append( dofm4 == GiNaC::numeric(0));

		}
		else if ( p->str().find("Box") != string::npos )
		{

			description = "Hermite_3D";

			polynom_space = legendre(3, 3, "a");
			polynom = polynom_space.op(0);
			variables = GiNaC::ex_to<GiNaC::lst>(polynom_space.op(1));

			for (int i=0; i<= 7; i++)
			{
				GiNaC::ex v = p->vertex(i);
				GiNaC::ex dofv   = polynom.subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdx = diff(polynom,x).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdy = diff(polynom,y).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdyx = diff(diff(polynom,y),x).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdz = diff(polynom,z).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdzy = diff(diff(polynom,z),y).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdzx = diff(diff(polynom,z),x).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));
				GiNaC::ex dofvdzyx = diff(diff(diff(polynom,z),y),x).subs(GiNaC::lst(x == v.op(0), y == v.op(1), z == v.op(2) ));

				equations.append( dofv   == GiNaC::numeric(0));
				equations.append( dofvdx == GiNaC::numeric(0));
				equations.append( dofvdy == GiNaC::numeric(0));
				equations.append( dofvdyx == GiNaC::numeric(0));
				equations.append( dofvdz == GiNaC::numeric(0));
								 // FIXME check Andrew/Ola numbering
				equations.append( dofvdzy == GiNaC::numeric(0));
								 // FIXME check Andrew/Ola numbering
				equations.append( dofvdzx == GiNaC::numeric(0));
				equations.append( dofvdzyx == GiNaC::numeric(0));

				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 0));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 1));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 2));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 3));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 4));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 5));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 6));
				dofs.insert(dofs.end(), GiNaC::lst(v.op(0), v.op(1), v.op(2), 7));

			}

		}

		GiNaC::matrix b; GiNaC::matrix A;
		matrix_from_equations(equations, variables, A, b);

		unsigned int ncols = A.cols();
		GiNaC::matrix vars_sq(ncols, ncols);

		// matrix of symbols
		for (unsigned r=0; r<ncols; ++r)
			for (unsigned c=0; c<ncols; ++c)
				vars_sq(r, c) = GiNaC::symbol();

		GiNaC::matrix id(ncols, ncols);

		// identity
		const GiNaC::ex _ex1(1);
		for (unsigned i=0; i<ncols; ++i)
			id(i, i) = _ex1;

		// invert the matrix
		GiNaC::matrix m_inv(ncols, ncols);
		m_inv = A.solve(vars_sq, id, GiNaC::solve_algo::gauss);

		for (unsigned int i=0; i<dofs.size(); i++)
		{
			b.let_op(i) = GiNaC::numeric(1);
			GiNaC::ex xx = m_inv.mul(GiNaC::ex_to<GiNaC::matrix>(b));

			GiNaC::lst subs;
			for (unsigned int ii=0; ii<xx.nops(); ii++)
			{
				subs.append(variables.op(ii) == xx.op(ii));
			}
			GiNaC::ex Nj= polynom.subs(subs).expand();
			Ns.insert(Ns.end(), Nj);
			b.let_op(i) = GiNaC::numeric(0);
		}

	}

}								 // namespace SyFi
