"""
Check that diagnostics generated by solver resolution failures work as
expected. In particular, check that every attempted solution that results in
a predicate failure yields a diagnostic in which the relevant logic contexts
are bound. As a consequence, this also checks that the DSL allows attaching
logic contexts to atoms.
"""

from langkit.dsl import (
    ASTNode, Field, T, UserField, abstract, synthetic
)
from langkit.envs import EnvSpec, add_env, add_to_env_kv
from langkit.expressions import (
    AbstractKind, And, Bind, Cond, Entity, LogicFalse, No, Predicate, Self,
    Var, langkit_property, lazy_field
)

from utils import build_and_run


class FooNode(ASTNode):
    # Root node type of this test language. Besides being the common base
    # class, it hosts the synthetic type nodes that the rest of the file
    # reaches through ``Self.unit.root``.

    # Synthetic NumType node: designated by the "number" type name and bound
    # to the type variable of number literals.
    @lazy_field()
    def number_type():
        return T.NumType.new()

    # Synthetic StrType node: designated by the "string" type name and bound
    # to the type variable of string literals.
    @lazy_field()
    def str_type():
        return T.StrType.new()

    # Return the first diagnostic of the given solver result.
    #
    # This is just to make sure that SolverResult is correctly exposed to the
    # DSL. NOTE(review): what happens when ``r`` has no diagnostics depends on
    # the out-of-bounds semantics of ``.at``; confirm before relying on it.

    @langkit_property(public=True, return_type=T.SolverDiagnostic)
    def get_first_diag(r=T.SolverResult):
        return r.diagnostics.at(0)


class Block(FooNode):
    # Sequence of statements. Its env_spec creates a new lexical environment
    # (add_env). Presumably the ProcDecl statements it contains register
    # themselves there, but the grammar is not visible here, so confirm.
    stmts = Field()
    env_spec = EnvSpec(add_env())


class Identifier(FooNode):
    token_node = True

    # Logic variable for the entity this identifier refers to: a type
    # declaration when it names a type, or the called declaration when it is
    # the name of a Call.
    ref_var = UserField(type=T.LogicVar, public=False)

    @langkit_property(return_type=T.TypeDecl.entity)
    def designated_type():
        # The names "number" and "string" designate the synthetic types
        # hosted by the unit root. Any other name designates no type at all.
        root = Var(Self.unit.root)
        return Cond(
            Self.symbol == "number", root.number_type.as_bare_entity,
            Self.symbol == "string", root.str_type.as_bare_entity,
            No(TypeDecl.entity),
        )

    @langkit_property(return_type=T.Equation)
    def xref_equation(ctx=T.LogicContext):
        # Bind ref_var to the designated type. ``ctx`` is attached to the
        # resulting atom so that solver diagnostics can report it.
        designated = Var(Self.designated_type)
        return Bind(Self.ref_var, designated, logic_ctx=ctx)


@abstract
class BaseTypeDecl(FooNode):
    # Predicate property: holds when this type entity is exactly the
    # ``expected`` entity. When it fails inside a Predicate atom, the
    # diagnostic is built from the template below. In this file, the
    # ``$Self`` part comes from the type that was actually found, and
    # ``$expected`` comes from the type that was required.
    @langkit_property(return_type=T.Bool,
                      predicate_error="expected $expected, got $Self")
    def match_expected_type(expected=T.BaseTypeDecl.entity):
        return Entity == expected


@abstract
class TypeDecl(BaseTypeDecl):
    # Common base class for the synthetic type declarations (NumType,
    # StrType). Identifier.designated_type returns entities of this type.
    pass


@synthetic
class NumType(TypeDecl):
    # Synthetic declaration of the builtin number type. It is created by the
    # FooNode.number_type lazy field rather than by parsing.
    pass


@synthetic
class StrType(TypeDecl):
    # Synthetic declaration of the builtin string type. It is created by the
    # FooNode.str_type lazy field rather than by parsing.
    pass


@abstract
class Expr(FooNode):
    # Logic variable holding the type of this expression during resolution.
    type_var = UserField(type=T.LogicVar, public=False)

    # Default typing equation: an expression kind that does not override
    # this property can never be resolved (the equation is always false).
    @langkit_property(return_type=T.Equation)
    def xref_equation():
        return LogicFalse()


class NumberLiteral(Expr):
    token_node = True

    # A number literal always has the builtin number type, which is hosted
    # by the unit root.
    @langkit_property(return_type=T.Equation)
    def xref_equation():
        return Bind(Self.type_var, Self.unit.root.number_type)


class StringLiteral(Expr):
    token_node = True

    # A string literal always has the builtin string type, which is hosted
    # by the unit root.
    @langkit_property(return_type=T.Equation)
    def xref_equation():
        return Bind(Self.type_var, Self.unit.root.str_type)


class ProcDecl(FooNode):
    # Declaration of a two-argument procedure. ``name`` is the procedure
    # name; ``first_type`` and ``second_type`` are the identifiers that
    # designate the expected type of each argument.
    name = Field()
    first_type = Field()
    second_type = Field()

    # Build the equation for a call to this procedure with the given
    # arguments:
    #   * resolve both parameter type identifiers, attaching ``ctx`` to their
    #     Bind atoms;
    #   * require each argument's type (its type_var) to match the
    #     corresponding parameter type, reporting a predicate failure on the
    #     offending argument node (``error_location``).
    # The order of the operands below is also the order in which the atoms
    # are built, so changing it may change the resulting diagnostics.
    @langkit_property(return_type=T.Equation)
    def call_equation(first_arg=T.Expr, second_arg=T.Expr, ctx=T.LogicContext):
        return And(
            Self.first_type.xref_equation(ctx),
            Self.second_type.xref_equation(ctx),

            Predicate(TypeDecl.match_expected_type,
                      first_arg.type_var,
                      Self.first_type.ref_var,
                      error_location=first_arg),
            Predicate(TypeDecl.match_expected_type,
                      second_arg.type_var,
                      Self.second_type.ref_var,
                      error_location=second_arg)
        )

    # Register this declaration in the current lexical environment, keyed by
    # its name, so that Call.resolve can look it up.
    env_spec = EnvSpec(
        add_to_env_kv(Self.name.symbol, Self)
    )


@abstract
class Resolvable(Expr):
    # Expression that can be resolved on its own (here: a call or a type
    # assertion). Each subclass builds a logic equation and solves it,
    # returning the solver result together with its diagnostics.
    @langkit_property(return_type=T.SolverResult, public=True,
                      kind=AbstractKind.abstract)
    def resolve():
        pass


class Call(Resolvable):
    # Call of the procedure named ``name`` with exactly two arguments.
    name = Field()
    first_arg = Field()
    second_arg = Field()

    @langkit_property()
    def resolve():
        # Every declaration visible under the called name is a candidate, and
        # logic_any lets the solver try each of them. Each attempt carries a
        # LogicContext (call name + candidate declaration). Diagnostics for a
        # failed attempt can therefore tell which candidate was tried.
        candidates = Var(Self.children_env.get(Self.name.symbol))
        equation = Var(And(
            Entity.first_arg.xref_equation,
            Entity.second_arg.xref_equation,
            candidates.logic_any(
                lambda decl: And(
                    Bind(Self.name.ref_var, decl),
                    decl.cast(ProcDecl).call_equation(
                        Self.first_arg,
                        Self.second_arg,
                        ctx=T.LogicContext.new(
                            ref_node=Entity.name,
                            decl_node=decl,
                        ),
                    ),
                ),
            ),
        ))
        return equation.solve_with_diagnostics


class TypeAssert(Resolvable):
    # Assertion that ``expr`` has the type designated by ``ident``.
    expr = Field()
    ident = Field()

    @langkit_property()
    def resolve():
        # Resolve the expression and the type identifier. No logic context is
        # attached to the identifier's Bind atom. Then require the
        # expression's type to match the designated type; a predicate failure
        # is reported on the expression node.
        asserted_type = Var(Entity.ident.designated_type)
        equation = Var(And(
            Entity.expr.xref_equation,
            Entity.ident.xref_equation(No(T.LogicContext)),
            Predicate(TypeDecl.match_expected_type,
                      Self.expr.type_var,
                      asserted_type,
                      error_location=Self.expr),
        ))
        return equation.solve_with_diagnostics


# Build the test language from expected_concrete_syntax.lkt and run main.py
# against the generated library. The meaning of ``types_from_lkt`` is defined
# by utils.build_and_run; check there before changing it.
build_and_run(
    py_script='main.py',
    lkt_file='expected_concrete_syntax.lkt',
    types_from_lkt=True,
)
print('Done')
