__author__ = "Marie E. Rognes (meg@simula.no)"
__copyright__ = "Copyright (C) 2011 Marie Rognes"
__license__  = "GNU LGPL version 3 or any later version"

# Based on original implementation by Martin Alnes and Anders Logg

# Last changed: 2011-02-21

import includes as incl
from functionspace import *
from form import generate_form
from capsules import UFCElementNames

# Only the top-level code generation entry point is exported
__all__ = ["generate_dolfin_code"]

# NB: generate_dolfin_namespace(...) assumes that if a coefficient has
# the same name in multiple forms, it is indeed the same coefficient.
# generate_dolfin_namespace(...) asserts that this flag is set before
# it extracts the (common) coefficient spaces.
parameters = {"use_common_coefficient_names": True}

#-------------------------------------------------------------------------------
def generate_dolfin_code(prefix, header, forms,
                         common_function_space=False, add_guards=False):
    """Assemble the complete text of a dolfin wrapper file.

    @param prefix:
        String, used as the name of the generated namespace and, if
        guards are requested, as the basis of the guard macro name.
    @param header:
        Code that will be inserted near the top of the file (directly
        after the dolfin tag).
    @param forms:
        List of UFCFormNames instances or single UFCElementNames.
    @param common_function_space:
        True if common function space, otherwise False
    @param add_guards:
        True iff guards (ifdefs) should be added
    @return:
        The generated code as a single newline-joined string.
    """

    # Generate the namespace block first
    namespace_block = generate_dolfin_namespace(prefix, forms,
                                                common_function_space)

    # File layout: tag, user header, includes, then the namespace
    sections = [incl.dolfin_tag, header]
    sections.extend([incl.stl_includes, incl.dolfin_includes])
    sections.append(namespace_block)

    # Optionally wrap everything in an include guard; the macro name is
    # the upper-cased prefix followed by _H
    if add_guards:
        macro = ("%s_h" % prefix).upper()
        sections.insert(0, "#ifndef %s\n#define %s\n" % (macro, macro))
        sections.append("\n#endif\n\n")

    return "\n".join(sections)

#-------------------------------------------------------------------------------
def generate_dolfin_namespace(prefix, forms, common_function_space=False):
    """Generate the namespace block containing all wrapper code.

    @param prefix:
        String, name of the generated namespace.
    @param forms:
        Either a single UFCElementNames instance (handled separately as
        a lone function space) or an iterable of forms. Each form is
        passed to generate_form together with the name "Form_<index>".
    @param common_function_space:
        Forwarded to generate_namespace_typedefs.
    @return:
        String with the code of the namespace block.
    """

    # A single element stands for one function space: separate code path
    if isinstance(forms, UFCElementNames):
        return generate_single_function_space(prefix, forms)

    # The coefficient space extraction below relies on common
    # coefficient names being enabled
    assert(parameters["use_common_coefficient_names"])
    coefficient_spaces = extract_coefficient_spaces(forms)

    # Function space classes for the (common) coefficient spaces; each
    # extracted entry is unpacked as the template's arguments
    blocks = []
    for coefficient_space in coefficient_spaces:
        blocks.append(apply_function_space_template(*coefficient_space))

    # One block of code per form, named after its position in 'forms'
    for (index, form) in enumerate(forms):
        blocks.append(generate_form(form, "Form_%d" % index))

    # Namespace typedefs (Bilinear/Linear & Test/Trial/Function) go last
    blocks.append(generate_namespace_typedefs(forms, common_function_space))

    # Wrap everything in the namespace block
    body = "\n".join(blocks)
    return "\nnamespace %s\n{\n\n%s\n}" % (prefix, body)

#-------------------------------------------------------------------------------
def generate_single_function_space(prefix, space):
    """Generate a namespace block holding a single 'FunctionSpace' class.

    @param prefix:
        String, name of the generated namespace.
    @param space:
        Object providing 'ufc_finite_element_classnames' and
        'ufc_dofmap_classnames'; only the first entry of each is used.
    @return:
        String with the code of the namespace block.
    """
    element_classname = space.ufc_finite_element_classnames[0]
    dofmap_classname = space.ufc_dofmap_classnames[0]
    space_code = apply_function_space_template("FunctionSpace",
                                               element_classname,
                                               dofmap_classname)
    return "\nnamespace %s\n{\n\n%s\n}" % (prefix, space_code)

#-------------------------------------------------------------------------------
def generate_namespace_typedefs(forms, common_function_space):
    """Generate convenience typedefs for the classes in the namespace.

    @param forms:
        Sequence of forms; only the attributes 'rank' and 'name' of
        each form are read here.
    @param common_function_space:
        True if a 'FunctionSpace' typedef should be added (this only
        happens if at least one form has nonzero rank).
    @return:
        String with the typedef code, preceded by a comment line, or ""
        if no typedef applies.
    """

    # Collect typedefs as (existing type, alias) pairs of strings
    pairs = []

    # Add typedef for BilinearForm/LinearForm/Functional (emitted in
    # that order) if exactly one form of the corresponding rank is present
    aliases = ["Functional", "LinearForm", "BilinearForm"]
    for rank in reversed(range(len(aliases))):
        forms_of_rank = [form for form in forms if form.rank == rank]
        if len(forms_of_rank) == 1:
            pairs.append(("Form_%s" % forms_of_rank[0].name, aliases[rank]))

    # Keepin' it simple: Add typedef for FunctionSpace if term applies.
    # Use the first form with nonzero rank rather than always Form_0,
    # since Form_0 may have rank 0. NOTE(review): presumably a rank 0
    # form has no TestSpace -- confirm against generate_form. The index
    # based name matches the "Form_%d" names used by
    # generate_dolfin_namespace.
    if common_function_space:
        for (i, form) in enumerate(forms):
            if form.rank:
                pairs.append(("Form_%d::TestSpace" % i, "FunctionSpace"))
                break

    # Combine data to typedef code
    typedefs = "\n".join("typedef %s %s;" % (existing, alias)
                         for (existing, alias) in pairs)

    # Return typedefs or ""
    if not typedefs:
        return ""
    return "// Class typedefs\n" + typedefs + "\n"

#-------------------------------------------------------------------------------
