"""This module provides a small Python layer on top of the C++
VariationalProblem/Solver classes as well as the solve function."""

# Copyright (C) 2011 Anders Logg
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Marie E. Rognes, 2011.
# Modified by Johan Hake, 2011.
#
# First added:  2011-06-22
# Last changed: 2011-07-04

# Names exported by a star-import of this module. The adaptive solver
# classes imported further down in this file are not listed here.
__all__ = ["LinearVariationalProblem",
           "LinearVariationalSolver",
           "NonlinearVariationalProblem",
           "NonlinearVariationalSolver",
           "solve"]

# Import SWIG-generated extension module (DOLFIN C++)
import dolfin.cpp as cpp

# Import UFL
import ufl

# Local imports
from dolfin.fem.form import Form
from dolfin.fem.formmanipulations import derivative

# Problem classes need special handling since they involve JIT compilation

class LinearVariationalProblem(cpp.LinearVariationalProblem):

    # The class documentation is taken verbatim from the C++ class
    __doc__ = cpp.LinearVariationalProblem.__doc__

    def __init__(self, a, L, u, bcs=None,
                 form_compiler_parameters=None):
        """
        Create linear variational problem a(u, v) = L(v).

        *Arguments*
            a
                the bilinear form (UFL)
            L
                the linear form (UFL)
            u
                the solution Function
            bcs
                (optional) a single boundary condition or a list/tuple
                of boundary conditions
            form_compiler_parameters
                (optional) parameters handed to the form compiler
                when the forms a and L are wrapped
        """

        # Validate the solution function and normalise the boundary
        # conditions to a sequence
        u = _extract_u(u)
        bcs = _extract_bcs(bcs)

        # A missing (or empty) parameter set becomes an empty dict
        form_compiler_parameters = form_compiler_parameters or {}

        # Remember the original UFL input and the compiler parameters
        self.a_ufl, self.L_ufl, self.u_ufl = a, L, u
        self.form_compiler_parameters = form_compiler_parameters

        # Wrap the UFL forms as DOLFIN forms (bilinear form first)
        a_wrapped = Form(a,
                         form_compiler_parameters=form_compiler_parameters)
        L_wrapped = Form(L,
                         form_compiler_parameters=form_compiler_parameters)

        # Hand the wrapped forms over to the C++ base class
        cpp.LinearVariationalProblem.__init__(self, a_wrapped, L_wrapped,
                                              u, bcs)

class NonlinearVariationalProblem(cpp.NonlinearVariationalProblem):

    # The class documentation is taken verbatim from the C++ class
    __doc__ = cpp.NonlinearVariationalProblem.__doc__

    def __init__(self, F, u, bcs=None, J=None,
                 form_compiler_parameters=None):
        """
        Create nonlinear variational problem F(u; v) = 0.

        *Arguments*
            F
                the residual form (UFL)
            u
                the solution Function
            bcs
                (optional) a single boundary condition or a list/tuple
                of boundary conditions
            J
                (optional) the Jacobian form J = dF/du (UFL)
            form_compiler_parameters
                (optional) parameters handed to the form compiler
                when the forms F and J are wrapped
        """

        # Validate the solution function and normalise the boundary
        # conditions to a sequence
        u = _extract_u(u)
        bcs = _extract_bcs(bcs)

        # A missing (or empty) parameter set becomes an empty dict
        form_compiler_parameters = form_compiler_parameters or {}

        # Remember the original UFL input and the compiler parameters
        self.F_ufl, self.J_ufl, self.u_ufl = F, J, u
        self.form_compiler_parameters = form_compiler_parameters

        # Wrap the residual form as a DOLFIN form
        F_wrapped = Form(F,
                         form_compiler_parameters=form_compiler_parameters)

        # Initialize the C++ base class; the Jacobian is only wrapped
        # and passed along when one was supplied
        if J is None:
            cpp.NonlinearVariationalProblem.__init__(self, F_wrapped,
                                                     u, bcs)
        else:
            J_wrapped = Form(J,
                             form_compiler_parameters=form_compiler_parameters)
            cpp.NonlinearVariationalProblem.__init__(self, F_wrapped,
                                                     u, bcs, J_wrapped)

# Solver classes are imported directly
from dolfin.cpp import LinearVariationalSolver, NonlinearVariationalSolver
from dolfin.fem.adaptivesolving import AdaptiveLinearVariationalSolver
from dolfin.fem.adaptivesolving import AdaptiveNonlinearVariationalSolver

# Solve function handles both linear systems and variational problems

def solve(*args, **kwargs):
    """Solve linear system Ax = b or variational problem a == L or F == 0.

    The DOLFIN solve() function can be used to solve either linear
    systems or variational problems. The following list explains the
    various ways in which the solve() function can be used.

    *1. Solving linear systems*

    A linear system Ax = b may be solved by calling solve(A, x, b),
    where A is a matrix and x and b are vectors. Optional arguments
    may be passed to specify the solver type and preconditioner
    type. Some examples are given below:

    .. code-block:: python

        solve(A, x, b)
        solve(A, x, b, "lu")
        solve(A, x, b, "gmres", "ilu")
        solve(A, x, b, "cg", "hypre_amg")

    Possible values for the solver type depend on which linear algebra
    backend is used and how that has been configured, but possible
    values may include:

        =============       ============================================
        Usage               Method
        =============       ============================================
        "lu"                LU factorization
        "cholesky"          Cholesky factorization
        "cg"                Conjugate gradient method
        "gmres"             Generalized minimal residual method
        "bicgstab"          Biconjugate gradient stabilized method
        "minres"            Minimal residual method
        "tfqmr"             Transpose-free quasi-minimal residual method
        "richardson"        Richardson method
        =============       ============================================

    Similarly, possible values for the preconditioner type may include:

        ==================  ============================================
        Usage               Preconditioner
        ==================  ============================================
        "none"              No preconditioner
        "ilu"               Incomplete LU factorization
        "icc"               Incomplete Cholesky factorization
        "jacobi"            Jacobi iteration
        "bjacobi"           Block Jacobi iteration
        "sor"               Successive over-relaxation
        "amg"               Algebraic multigrid (either BoomerAMG or ML)
        "additive_schwarz"  Additive Schwarz
        "hypre_amg"         Hypre algebraic multigrid (BoomerAMG)
        "hypre_euclid"      Hypre parallel incomplete LU factorization
        "hypre_parasails"   Hypre parallel sparse approximate inverse
        "ml_amg"            ML algebraic multigrid
        ==================  ============================================

    *2. Solving linear variational problems*

    A linear variational problem a(u, v) = L(v) for all v may be
    solved by calling solve(a == L, u, ...), where a is a bilinear
    form, L is a linear form, u is a Function (the solution). Optional
    arguments may be supplied to specify boundary conditions or solver
    parameters. Some examples are given below:

    .. code-block:: python

        solve(a == L, u)
        solve(a == L, u, bc)
        solve(a == L, u, [bc1, bc2])

        solve(a == L, u, bcs,
              solver_parameters={"linear_solver": "lu"},
              form_compiler_parameters={"optimize": True})

    *3. Solving nonlinear variational problems*

    A nonlinear variational problem F(u; v) = 0 for all v may be
    solved by calling solve(F == 0, u, ...), where the residual F is a
    linear form (linear in the test function v but possibly nonlinear
    in the unknown u) and u is a Function (the solution). Optional
    arguments may be supplied to specify boundary conditions, the
    Jacobian form or solver parameters. If the Jacobian is not
    supplied, it will be computed by automatic differentiation of the
    residual form. Some examples are given below:

    .. code-block:: python

        solve(F == 0, u)
        solve(F == 0, u, bc)
        solve(F == 0, u, [bc1, bc2])

        solve(F == 0, u, bcs, J=J,
              solver_parameters={"linear_solver": "lu"},
              form_compiler_parameters={"optimize": True})

    *4. Solving linear/nonlinear variational problems adaptively*

    If the keyword argument tol is supplied, the variational problem
    is handed to an adaptive variational solver instead. The
    tolerance tol must be a positive number, and the goal functional
    is given by the keyword argument M, which must be a UFL form.
    Some examples are given below:

    .. code-block:: python

        solve(a == L, u, bcs, tol=tol, M=M)
        solve(F == 0, u, bcs, J=J, tol=tol, M=M)

    *Returns*

    For linear systems, the value returned by the wrapped C++ solve
    function. For variational problems nothing is returned; the
    solution is stored in u.
    """

    # Validate explicitly instead of using assert: assert statements
    # are removed when Python runs with -O, which would turn a missing
    # argument into an obscure IndexError further down
    if len(args) == 0:
        cpp.dolfin_error("solving.py",
                         "solve linear system or variational problem",
                         "Missing arguments, expecting solve(A, x, b) or solve(lhs == rhs, u, ...)")

    # Call adaptive solve if we get a tolerance
    if "tol" in kwargs:
        _solve_varproblem_adaptive(*args, **kwargs)

    # Call variational problem solver if we get an equation (but not a tolerance)
    elif isinstance(args[0], ufl.classes.Equation):
        _solve_varproblem(*args, **kwargs)

    # Default case, just call the wrapped C++ solve function. Only the
    # positional arguments are forwarded; keyword arguments are not
    # passed on to the C++ layer.
    else:
        return cpp.solve(*args)

def _solve_varproblem(*args, **kwargs):
    """Solve variational problem a == L or F == 0

    The equation is treated as linear when both of its sides are UFL
    forms, and as nonlinear (F == 0) otherwise. Nothing is returned.
    """

    # Parse the arguments (tol and M are extracted but not used here)
    eq, u, bcs, J, tol, M = _extract_args(*args, **kwargs)
    compiler_options = kwargs.get("form_compiler_parameters", {})
    solver_options = kwargs.get("solver_parameters", {})

    if isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form):

        # Linear case: a == L
        problem = LinearVariationalProblem(
            eq.lhs, eq.rhs, u, bcs,
            form_compiler_parameters=compiler_options)
        solver = LinearVariationalSolver(problem)

    else:

        # Nonlinear case: differentiate the residual when no Jacobian
        # was supplied by the caller
        if J is None:
            cpp.info("No Jacobian form specified for nonlinear variational problem.")
            cpp.info("Differentiating residual form F to obtain Jacobian J = F'.")
            J = derivative(eq.lhs, u)

        problem = NonlinearVariationalProblem(
            eq.lhs, u, bcs, J,
            form_compiler_parameters=compiler_options)
        solver = NonlinearVariationalSolver(problem)

    # Apply the user-supplied solver parameters and solve
    solver.parameters.update(solver_options)
    solver.solve()

def _solve_varproblem_adaptive(*args, **kwargs):
    """Solve variational problem a == L or F == 0 adaptively

    The equation is treated as linear when both of its sides are UFL
    forms, and as nonlinear (F == 0) otherwise. The adaptive solver is
    called with the tolerance tol and the goal functional M. Nothing
    is returned.
    """

    # Parse the arguments
    eq, u, bcs, J, tol, M = _extract_args(*args, **kwargs)
    compiler_options = kwargs.get("form_compiler_parameters", {})
    solver_options = kwargs.get("solver_parameters", {})

    # Classify the equation once
    is_linear = isinstance(eq.lhs, ufl.Form) and isinstance(eq.rhs, ufl.Form)

    # A nonlinear problem without a user-supplied Jacobian gets one by
    # differentiating the residual
    if not is_linear and J is None:
        cpp.info("No Jacobian form specified for nonlinear variational problem.")
        cpp.info("Differentiating residual form F to obtain Jacobian J = F'.")
        J = derivative(eq.lhs, u)

    # Create the problem together with its matching adaptive solver
    if is_linear:
        problem = LinearVariationalProblem(
            eq.lhs, eq.rhs, u, bcs,
            form_compiler_parameters=compiler_options)
        solver = AdaptiveLinearVariationalSolver(problem)
    else:
        problem = NonlinearVariationalProblem(
            eq.lhs, u, bcs, J,
            form_compiler_parameters=compiler_options)
        solver = AdaptiveNonlinearVariationalSolver(problem)

    # Apply the user-supplied solver parameters and solve.
    # NOTE(review): M is None when the caller gave no goal functional;
    # confirm how the adaptive solvers treat that.
    solver.parameters.update(solver_options)
    solver.solve(tol, M)

def _extract_args(*args, **kwargs):
    """Common extraction of arguments for _solve_varproblem[_adaptive]

    *Arguments*
        args
            the positional arguments (eq, u) optionally followed by
            the boundary conditions
        kwargs
            the keyword arguments; the keys read here are "bcs", "J",
            "tol" and "M"

    *Returns*
        tuple
            (eq, u, bcs, J, tol, M), where J, tol and M are None when
            they were not supplied

    Boundary conditions may be given either as the third positional
    argument or as the keyword argument bcs; the positional argument
    takes precedence when both are present. Invalid input is reported
    through cpp.dolfin_error.
    """

    # Check that we got at least the equation and the solution
    if not len(args) >= 2:
        cpp.dolfin_error("solving.py",
                         "solve variational problem",
                         "Missing arguments, expecting solve(lhs == rhs, u, ...)")

    # Extract equation
    eq = _extract_eq(args[0])

    # Extract solution function
    u = _extract_u(args[1])

    # Extract boundary conditions. Previously a keyword argument bcs
    # was silently ignored, so that solve(a == L, u, bcs=bc) solved the
    # problem without any boundary conditions; it is now honoured.
    # _extract_bcs maps a missing value (None) to an empty list.
    if len(args) > 2:
        bcs = _extract_bcs(args[2])
    else:
        bcs = _extract_bcs(kwargs.get("bcs", None))

    # Extract Jacobian
    J = kwargs.get("J", None)
    if J is not None and not isinstance(J, ufl.Form):
        cpp.dolfin_error("solving.py",
                         "solve variational problem",
                         "Expecting Jacobian J to be a UFL Form")

    # Extract tolerance
    tol = kwargs.get("tol", None)
    if tol is not None and not (isinstance(tol, (float, int)) and tol > 0.0):
        cpp.dolfin_error("solving.py",
                         "solve variational problem",
                         "Expecting tolerance tol to be a positive number")

    # Extract functional
    M = kwargs.get("M", None)
    if M is not None and not isinstance(M, ufl.Form):
        cpp.dolfin_error("solving.py",
                         "solve variational problem",
                         "Expecting goal functional M to be a UFL Form")

    return eq, u, bcs, J, tol, M

def _extract_eq(eq):
    """Extract and check argument eq

    Returns eq unchanged; anything that is not a UFL Equation is
    reported through cpp.dolfin_error.
    """

    # Accept a proper equation right away
    if isinstance(eq, ufl.classes.Equation):
        return eq

    cpp.dolfin_error("solving.py",
                     "solve variational problem",
                     "Expecting first argument to be an Equation")
    return eq

def _extract_u(u):
    """Extract and check argument u

    Returns u unchanged; anything that is not a Function is reported
    through cpp.dolfin_error.
    """

    is_function = isinstance(u, cpp.Function)
    if not is_function:
        cpp.dolfin_error(
            "solving.py",
            "solve variational problem",
            "Expecting second argument to be a Function")

    return u

def _extract_bcs(bcs):
    "Extract and check argument bcs"
    if bcs is None:
        bcs = []
    elif not isinstance(bcs, (list, tuple)):
        bcs = [bcs]
    for bc in bcs:
        if not isinstance(bc, cpp.BoundaryCondition):
            cpp.dolfin_error("solving.py",
                             "solve variational problem",
                             "Unable to extract boundary condition arguments")
    return bcs
