"""
This test is quite complex, as it instantiates logic formulas to test the
subprogram wrappers generated by TGen.

Subprogram wrappers check that the precondition of the original subprogram
holds before actually calling it. Note that the subprogram wrapper actually
checks the inputs against the DNF (disjunctive normal form) of the
precondition, so we need to check that the transformation performed by TGen
is sound.

To do this, we instantiate logic formulas and generate their truth tables. For
each formula, we generate a subprogram whose precondition is the logic formula
(and whose parameters are the formula literals). We then check that the
behaviors of the original subprogram and of the TGen wrapper are consistent.

We thus generate the truth table of the formula, and for each valuation:
   * If the formula evaluates to False for the valuation, then we check that
     calling the original subprogram raises an Ada.Assertions.Assertion_Error
     exception, and that calling the wrapper raises a TGen.Precondition_Error
     exception.
   * If the formula evaluates to True for the valuation, then we simply call
     the original and the wrapper subprograms: neither should raise an
     exception.
"""


from e3.fs import mkdir

from drivers.utils import run

import os
import shutil


class LogicalFormula:
    """A propositional logic formula over named Boolean literals.

    A formula is a tree node: ``operator`` is one of ``"lit"``, ``"not"``,
    ``"and"`` or ``"or"`` and ``operands`` holds its children. For ``"lit"``
    the single operand is the literal's name (a string); for the other
    operators the operands are themselves ``LogicalFormula`` instances.

    Formulas are usually built with the ``~``, ``&`` and ``|`` operators,
    e.g. ``(~a) | b``.
    """

    def __init__(self, operator, *operands):
        self.operator = operator
        self.operands = operands

    def __invert__(self):
        """Return the negation of this formula."""
        return LogicalFormula("not", self)

    def __and__(self, other):
        """Return the conjunction of this formula and ``other``."""
        return LogicalFormula("and", self, other)

    def __or__(self, other):
        """Return the disjunction of this formula and ``other``."""
        return LogicalFormula("or", self, other)

    def literals(self):
        """Return the set of literal names occurring in this formula.

        Raises:
            ValueError: if the formula (or a sub-formula) has an unknown
                operator.
        """
        if self.operator == "lit":
            return {self.operands[0]}
        elif self.operator == "not":
            return self.operands[0].literals()
        elif self.operator in ("or", "and"):
            return set().union(*(operand.literals() for operand in self.operands))
        else:
            # Previously this fell through and returned None; fail loudly
            # instead, consistently with evaluate().
            raise ValueError("Invalid operator")

    def evaluate(self, context):
        """Evaluate the formula under ``context``.

        Args:
            context: mapping from literal name to its truth value. A literal
                missing from ``context`` evaluates to ``None`` (falsy).

        Returns:
            The truth value of the formula.

        Raises:
            ValueError: if the formula has an unknown operator.
        """
        if self.operator == "not":
            operand_value = self.operands[0].evaluate(context)
            return not operand_value
        elif self.operator == "and":
            operand_values = [operand.evaluate(context) for operand in self.operands]
            return all(operand_values)
        elif self.operator == "or":
            operand_values = [operand.evaluate(context) for operand in self.operands]
            return any(operand_values)
        elif self.operator == "lit":
            return context.get(self.operands[0])
        else:
            raise ValueError("Invalid operator")

    def to_ada(self):
        """Return the formula as an Ada Boolean expression.

        Every operand of a binary connective is parenthesized, and the
        short-circuit forms ``and then`` / ``or else`` are used.

        Raises:
            ValueError: if the formula has an unknown operator.
        """
        if self.operator == "lit":
            return self.operands[0]
        elif self.operator == "not":
            return "not (" + self.operands[0].to_ada() + ")"
        elif self.operator in ("and", "or"):
            connector = " and then " if self.operator == "and" else " or else "
            return connector.join(
                "(" + operand.to_ada() + ")" for operand in self.operands
            )
        else:
            # Previously this fell through and returned None, which would
            # only surface later as a TypeError on string concatenation.
            raise ValueError("Invalid operator")


# Helpers to check different kinds of formulas

# Generate the truth table of a formula, and generate code that calls both the
# original subprogram and its TGen wrapper for each valuation.


def generate_truth_formula(formula):
    """Return the truth table of ``formula``.

    The result is a list of ``(valuation, value)`` pairs, one per possible
    assignment of the formula's literals. ``valuation`` maps each literal name
    to a bool, and ``value`` is ``formula.evaluate(valuation)``. Rows are
    ordered like a binary counter starting from all-False, where the first
    literal (in ``formula.literals()`` iteration order) is the least
    significant bit.
    """
    names = list(formula.literals())

    def valuation_for(row):
        # Bit ``pos`` of the row number gives the value of the pos-th literal.
        return {name: bool((row >> pos) & 1) for pos, name in enumerate(names)}

    valuations = (valuation_for(row) for row in range(1 << len(names)))
    return [(valuation, formula.evaluate(valuation)) for valuation in valuations]


def generate_call(subp_spec, valuations):
    """Return an Ada call to ``subp_spec`` using named parameter association.

    ``valuations`` maps parameter names to values. Each value is rendered
    with its default string form (Python ``True``/``False`` are also valid
    Ada Boolean literals). Associations are separated by a bare comma, in the
    mapping's iteration order. No trailing semicolon is added.
    """
    associations = [f"{name} => {value}" for name, value in valuations.items()]
    return "".join([subp_spec, "(", ",".join(associations), ")"])


def check_exception(indent, file_adb, stmt, exc):
    """Write an Ada block asserting that ``stmt`` raises ``exc``.

    Args:
        indent: leading whitespace for the ``begin``/``exception``/``end``
            lines; nested lines get three extra spaces per level.
        file_adb: text stream the Ada code is written to.
        stmt: Ada statement to execute; expected to already end with
            ``";\\n"``.
        exc: name of the Ada exception expected from ``stmt``.
    """
    inner = indent + "   "
    innermost = inner + "   "
    block = [
        f"{indent}begin\n",
        f"{inner}{stmt}",
        # Reaching this line means no exception was raised: fail with an
        # exception the handler below does not catch (assuming ``exc`` is not
        # Program_Error itself).
        f"{inner}raise Program_Error;\n",
        f"{indent}exception\n",
        f"{inner}when {exc} =>\n",
        f"{innermost}null;\n",
        f"{indent}end;\n",
    ]
    file_adb.write("".join(block))


def generate_function(pkg_adb, pkg_ads, test_pkg_wrapper_adb, formula):
    """Generate an Ada procedure for ``formula`` and the code testing it.

    Writes to the three given (already opened) text streams:

    * ``pkg_ads``: the spec of a procedure with one Boolean parameter per
      literal of ``formula``, with ``formula`` as its ``Pre`` aspect;
    * ``pkg_adb``: the body of that procedure, which does nothing (``null``);
    * ``test_pkg_wrapper_adb``: for each row of the formula's truth table,
      calls to both ``Pkg.<name>`` and ``Pkg.TGen_Wrappers.<name>``. When the
      formula is False for the row, each call is wrapped in a block expecting
      ``Ada.Assertions.Assertion_Error`` (original) or
      ``TGen.Precondition_Error`` (wrapper). Otherwise the calls are emitted
      as-is, so any exception makes the test fail.

    Raises:
        ValueError: if ``formula`` has no literals (no parameter list could
            be generated for the procedure).
    """
    t = "   "

    # Derive a per-formula name from its hash. Use the absolute value: a
    # negative hash would yield an invalid Ada identifier such as "F_-42".
    subp_name = "F_" + str(abs(hash(formula)))
    literals = list(formula.literals())
    if not literals:
        raise ValueError("formula has no literals")
    subp_spec = f"procedure {subp_name} ({', '.join(literals)} : Boolean)"

    pkg_ads.write(t + subp_spec)
    pkg_ads.write(" with Pre => " + formula.to_ada() + ";\n")
    pkg_adb.write(f"{t}{subp_spec} is\n")
    pkg_adb.write(f"{2*t}begin\n")
    pkg_adb.write(f"{3*t}null;\n")
    pkg_adb.write(f"{2*t}end {subp_name};\n")

    # Generate the truth table for the formula: each entry pairs a valuation
    # of the literals with the value of the formula under that valuation.
    truth_table = generate_truth_formula(formula)

    # Now generate the test
    test_pkg_wrapper_adb.write(f"\n{t}--  Start of formula\n\n")
    for valuations, result in truth_table:
        call_to_orig = generate_call(f"Pkg.{subp_name}", valuations) + ";\n"
        call_to_wrapper = (
            generate_call(f"Pkg.TGen_Wrappers.{subp_name}", valuations) + ";\n"
        )

        if result:
            # The precondition holds: neither call may raise an exception.
            test_pkg_wrapper_adb.write(t + call_to_orig)
            test_pkg_wrapper_adb.write(t + call_to_wrapper)
        else:
            # The precondition is violated: check that each call raises the
            # exception expected for it.
            check_exception(
                t, test_pkg_wrapper_adb, call_to_orig, "Ada.Assertions.Assertion_Error"
            )
            check_exception(
                t, test_pkg_wrapper_adb, call_to_wrapper, "TGen.Precondition_Error"
            )
    test_pkg_wrapper_adb.write(f"\n{t}--  End of formula\n")


# Generate the Ada sources under test:
# - "test/" holds the "Pkg" package (one procedure per formula, with the
#   formula as its precondition) and its project file;
# - the current directory holds the "Test_Pkg_Wrapper" driver program and
#   its project file.
# Files are opened with "with" so they are closed even if generation fails
# midway.

mkdir("test")
mkdir(os.path.join("test", "obj"))
mkdir("obj")

with open(os.path.join("test", "test.gpr"), "w") as orig_project:
    orig_project.write(
        """
project Test is
   for Source_Dirs use (".");
   for Object_Dir use "obj";
end Test;
"""
    )

with open("test_pkg_wrapper.gpr", "w") as test_pkg_wrapper_prj:
    test_pkg_wrapper_prj.write(
        """
with "test/test.gpr";
with "tgen_support/tgen_support.gpr";

project Test_Pkg_Wrapper is
   for Source_Dirs use (".");
   for Object_Dir use "obj";
   for Main use ("test_pkg_wrapper.adb");
end Test_Pkg_Wrapper;
"""
    )

# Instantiate a pool of literals for formulas
a = LogicalFormula("lit", "a")
b = LogicalFormula("lit", "b")
c = LogicalFormula("lit", "c")
d = LogicalFormula("lit", "d")
e = LogicalFormula("lit", "e")

formulas = [
    # Simple formulas
    (~a) | b,
    a | b | c,
    # More complex formula, checking distributivity of the or else
    (a & (b | c)) | (d | e),
    # Check the distributivity of the negation using DeMorgan laws
    ~((a & (b | c)) | (d | e)),
]

with open(os.path.join("test", "pkg.adb"), "w") as pkg_adb, open(
    os.path.join("test", "pkg.ads"), "w"
) as pkg_ads, open("test_pkg_wrapper.adb", "w") as test_pkg_wrapper_adb:
    pkg_adb.write("package body Pkg is\n")
    pkg_ads.write("package Pkg is\n")

    test_pkg_wrapper_adb.write(
        "with Ada.Assertions;\n"
        "with TGen;\n"
        "with Pkg.TGen_Wrappers;\n"
        "with Pkg;\n"
        "procedure Test_Pkg_Wrapper is\n"
        "begin\n"
    )

    for formula in formulas:
        generate_function(pkg_adb, pkg_ads, test_pkg_wrapper_adb, formula)

    pkg_adb.write("end Pkg;\n")
    pkg_ads.write("end Pkg;\n")
    test_pkg_wrapper_adb.write("end Test_Pkg_Wrapper;\n")

# Wrap everything together: first, generate the TGen support library for the
# "Pkg" package (presumably providing the Pkg.TGen_Wrappers package used by
# the driver — confirm against tgen_marshalling docs). The TGen templates are
# looked up relative to the installation prefix of gnattest.

gnattest_path = shutil.which("gnattest")
if gnattest_path is None:
    # Without this check, os.path.dirname(None) fails with an opaque TypeError.
    raise RuntimeError("gnattest not found on PATH: cannot locate TGen templates")
laltools_root = os.path.dirname(os.path.dirname(gnattest_path))
templates_path = os.path.join(laltools_root, "share", "tgen", "templates")
pkg_ads_path = os.path.join("test", "pkg.ads")
run(
    f"tgen_marshalling -P test/test.gpr --templates-dir='{templates_path}'"
    f" -o tgen_support '{pkg_ads_path}'"
)

# Then, build the test project. -gnata enables assertion checks, so the Pre
# aspects of the original subprograms are actually evaluated.
run("gprbuild -q -P test_pkg_wrapper.gpr -cargs:Ada -gnata")

# Run it. We don't expect any error
run("obj/test_pkg_wrapper")
