# Copyright (C) 2010 Marie E. Rognes
#
# This file is part of FFC.
#
# FFC is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# FFC is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with FFC.  If not, see <http://www.gnu.org/licenses/>.
#
# Last changed: 2011-01-24

from ufl import Coefficient, action, FiniteElement, TestFunction, derivative
from ufl.algorithms.analysis import extract_arguments
from ufl.algorithms.elementtransformations import increase_order, tear

from ffc.log import info, error
from ffc.compiler import compile_form

from ffc.errorcontrol.formmanipulations import *

# Public interface of this module: only compile_with_error_control is
# exported by "from ... import *"; the other functions are helpers.
__all__ = ["compile_with_error_control"]

def _check_input(forms, object_names):
    """
    Check the input forms and bring them into a canonical triple.

    Can handle two variants of forms:

    (a, L, M): a has rank 2, L has rank 1 and M has rank 1

    (F, M): F has rank 1 and M has rank 0

    (The variant (F, dF, M) with a user-supplied Jacobian is not supported.)

    Arguments:
        forms: tuple of UFL forms in one of the variants above.
        object_names: dict of names for UFL objects. The entry "unknown",
            if present, is read as the unknown of the problem. In the
            (a, L, M) case, a name for the generated discrete primal
            solution is added to this dict (keyed by id of the object).

    Returns:
        (forms, unknown, nonlinear) where
          - forms is (dF, F, M) for the (F, M) variant, with dF the
            derivative of F with respect to unknown, or
            (a, L, action(M, unknown)) for the (a, L, M) variant;
          - unknown is the given or generated unknown coefficient;
          - nonlinear is True for the (F, M) variant, False otherwise.

    Errors (reported through ffc.log.error) are raised when the number of
    forms is wrong, when 'unknown' is missing in the (F, M) case, or when
    'unknown' is given in the (a, L, M) case (not implemented).
    """

    # Extract unknown (or None if not defined)
    unknown = object_names.get("unknown", None)

    if len(forms) == 2:
        (F, M) = forms

        # Check that unknown is defined. Use an explicit check instead of
        # assert (which is stripped under -O) and instead of relying on the
        # truthiness of a UFL object.
        if unknown is None:
            error("Variable 'unknown' must be defined!")

        # Generate jacobian of F
        dF = derivative(F, unknown)
        return (dF, F, M), unknown, True

    if len(forms) != 3:
        error("Wrong form input.")

    (a, L, M) = forms

    # If unknown is undefined, define discrete solution as coefficient
    # on trial element and make M a functional (instead of a linear form)
    if unknown is None:
        # The trial function is the second argument of the bilinear form a
        V = extract_arguments(a)[1].element()
        unknown = Coefficient(V)
        object_names[id(unknown)] = "__discrete_primal_solution"
        M = action(M, unknown)
    else:
        error("Not implemented!")

    return (a, L, M), unknown, False

def generate_error_control(forms, object_names, module=None):
    """
    Generate the additional forms and object names needed for error control.

    Arguments:
        forms: input forms, either (a, L, M) or (F, M) (see _check_input).
        object_names: dict of names for UFL objects. It may be modified
            by _check_input in the (a, L, M) case.
        module: module handed to the form generators imported from
            ffc.errorcontrol.formmanipulations; defaults to ufl.

    Returns:
        (ec_forms, ec_names, forms, linear) where
          - ec_forms is the tuple (a_star, L_star, a_R_T, L_R_T, a_R_dT,
            L_R_dT, eta_h, eta_T) of generated forms;
          - ec_names maps id(object) -> reserved name for the coefficients
            created here;
          - forms is the canonical triple returned by _check_input;
          - linear is True for the (a, L, M) variant.
    """

    info("Generating additionals")

    if module is None:
        module = __import__("ufl")

    # Check input and extract appropriate forms
    forms, unknown, nonlinear = _check_input(forms, object_names)

    # Generate dual forms
    a_star, L_star = generate_dual_forms(forms, unknown, module)

    # Element of the unknown (in the linear case this is the trial
    # element of the bilinear form a, see _check_input)
    V = unknown.element()

    # Generate extrapolation space by increasing order of V
    E = increase_order(V)

    # Dictionary for storing object names generated by error control,
    # keyed by id of the object
    ec_names = {}

    # Create coefficient for improved dual
    Ez_h = Coefficient(E)
    ec_names[id(Ez_h)] = "__improved_dual"

    # Create weak residual
    if nonlinear:
        weak_residual = generate_weak_residual(forms[1])
    else:
        weak_residual = generate_weak_residual(forms[:-1], unknown)

    # Generate error estimate (residual) (# FIXME: Add option here)
    eta_h = action(weak_residual, Ez_h)

    # Define approximation space for cell and facet residuals
    V_h = tear(V)

    # Define bubble (degree = geometric dimension + 1)
    B = FiniteElement("B", V.cell(), V.cell().geometric_dimension()+1)
    b_T = Coefficient(B)
    ec_names[id(b_T)] = "__cell_bubble"

    # Create cell residual forms
    a_R_T, L_R_T = generate_cell_residual(weak_residual, V_h, b_T, module)

    # Define coefficient for cell residual
    R_T = Coefficient(V_h)
    ec_names[id(R_T)] = "__cell_residual"

    # Establish cone function(s)
    C = FiniteElement("DG", V.cell(), V.cell().geometric_dimension())
    b_e = Coefficient(C)
    ec_names[id(b_e)] = "__cell_cone"

    # Create facet residual forms
    a_R_dT, L_R_dT = generate_facet_residual(weak_residual, V_h, b_e, R_T,
                                             module)

    # Define coefficients for the facet residual and for the discrete dual
    # solution (on the trial element of the dual bilinear form)
    R_dT = Coefficient(V_h)
    z_h = Coefficient(extract_arguments(a_star)[1].element())
    ec_names[id(R_dT)] = "__facet_residual"
    ec_names[id(z_h)] = "__discrete_dual_solution"

    # Generate error indicators (# FIXME: Add option here)
    v = TestFunction(FiniteElement("DG", V.cell(), 0))
    eta_T = generate_error_indicator(weak_residual, R_T, R_dT, Ez_h, z_h, v)

    ec_forms = (a_star, L_star, a_R_T, L_R_T, a_R_dT, L_R_dT, eta_h, eta_T)

    return (ec_forms, ec_names, forms, not nonlinear)

def compile_with_error_control(forms, object_names, prefix, parameters):
    """
    Compile the input forms together with generated error control forms.

    Arguments:
        forms: input forms, either (a, L, M) or (F, M).
        object_names: dict of names for UFL objects (keyed by id of the
            object, plus an optional "unknown" entry). Updated in place
            with the names generated for error control.
        prefix: passed on unchanged to compile_form.
        parameters: passed on unchanged to compile_form.

    Returns:
        0

    An error (via ffc.log.error) is raised if a user-defined name clashes
    with one of the reserved error control names.
    """

    # Generate additional forms (and names) for error control
    ec_forms, ec_names, forms, _ = generate_error_control(forms, object_names)

    # Check that there are no conflicts between user defined and
    # generated names. ec_names maps id(object) -> name, so the reserved
    # names are its values; its keys are object ids.
    reserved = set(ec_names.values())
    if set(object_names.values()) & reserved:
        comment = "%s are reserved error control names." % str(sorted(reserved))
        error("Conflict between user defined and generated names: %s" % comment)

    # Register the generated names (id(object) -> name)
    object_names.update(ec_names)

    # Compile error control and input (pde + goal) forms as normal
    compile_form(ec_forms + forms, object_names, prefix, parameters)

    return 0
