//
// Created by Martin Blicha on 20.10.18.
//

#ifndef OPENSMT_BOOLREWRITING_H
#define OPENSMT_BOOLREWRITING_H

#include "PTRef.h"
#include "Logic.h"

#include <vector>
#include <unordered_map>


// Declaration only; the definition is not part of this header.
// Results are written into the caller-supplied map `PTRefToIncoming` (PTRef -> int).
// NOTE(review): judging by the name, the int is presumably the number of incoming
// edges (parent occurrences) of each term reachable from `tr` -- confirm against the definition.
void computeIncomingEdges(const Logic& logic, PTRef tr, std::unordered_map<PTRef,int,PTRefHash>& PTRefToIncoming);

// Declaration only. NOTE(review): presumably a wrapper around the rewriteMaxArity template
// defined in this header that merges nested and/or arguments wherever possible -- confirm.
// The spelling "Aggresive" is part of the public name and must be kept for existing callers.
PTRef rewriteMaxArityAggresive(Logic & logic, PTRef root);

// Declaration only. NOTE(review): presumably a more conservative variant of the function
// declared just before it (e.g. not merging shared sub-terms) -- confirm against the definition.
PTRef rewriteMaxArityClassic(Logic & logic, PTRef root);

// Declaration only. Takes the term `root` and returns a term.
// NOTE(review): the simplification performed is not visible from this header -- see the definition.
PTRef simplifyUnderAssignment(Logic & logic, PTRef root);

// Declaration only. Note the parameter order (root, logic): it is the reverse of the
// other functions declared in this header.
PTRef simplifyUnderAssignment_Aggressive(PTRef root, Logic & logic);

// Declaration only. NOTE(review): presumably returns the terms reachable from `root`
// in post-order (children before parents) -- confirm against the definition.
std::vector<PTRef> getPostOrder(PTRef root, Logic& logic);

// Declaration only. Returns a PTRef -> PTRef map.
// NOTE(review): presumably maps each term reachable from `root` to its immediate dominator
// in the term DAG -- confirm against the definition.
std::unordered_map<PTRef, PTRef, PTRefHash> getImmediateDominators(PTRef root, Logic & logic);

// Rebuilds the conjunction/disjunction `tr` from the cached replacements of its arguments,
// splicing in the arguments of any replacement that has the same top-level symbol as `tr`.
//
// logic      - owner of the terms; used for term lookup and for building the result.
// tr         - term to rebuild; must satisfy logic.isAnd(tr) || logic.isOr(tr) (asserted).
// cache      - maps every argument of `tr` to its replacement; looked up with operator[]
//              for each argument, so all arguments are expected to be present.
// doNotMerge - callable taking the ORIGINAL argument (not its replacement); when it yields
//              true the replacement is kept as a single argument instead of being spliced in.
//              It is only invoked for arguments whose replacement has the same symbol as `tr`.
//
// Returns `tr` itself when no argument was replaced and nothing was spliced in; otherwise
// the term produced by logic.mkAnd / logic.mkOr over the collected arguments.
template<typename T>
PTRef mergeAndOrArgs(Logic & logic, PTRef tr, Map<PTRef,PTRef,PTRefHash>& cache, T doNotMerge)
{
    assert(logic.isAnd(tr) || logic.isOr(tr));
    const Pterm& term = logic.getPterm(tr);
    const SymRef connective = term.symb();
    vec<PTRef> collected;
    bool modified = false;
    for (PTRef original : term) {
        const PTRef replacement = cache[original];
        if (replacement != original) {
            modified = true;
        }
        const bool sameConnective = (logic.getSymRef(replacement) == connective);
        if (!sameConnective || doNotMerge(original)) {
            // Different connective, or merging vetoed: keep the replacement as one argument.
            collected.push(replacement);
            continue;
        }
        // Same connective and merging allowed: lift the replacement's arguments one level up.
        modified = true;
        const Pterm& nested = logic.getPterm(replacement);
        for (PTRef nestedArg : nested) {
            collected.push(nestedArg);
        }
    }
    if (!modified) {
        return tr;
    }
    // New terms are created only here, after all reads through the Pterm references above.
    if (connective == logic.getSym_and()) {
        return logic.mkAnd(std::move(collected));
    }
    return logic.mkOr(std::move(collected));
}


// Rewrites the Boolean structure under `root` bottom-up, using an explicit stack instead of
// recursion, and returns the rewritten counterpart of `root`.
//
// Every visited term is recorded in a local cache (original -> rewritten):
//   * a child for which logic.isAtom holds is mapped to itself;
//   * and/or terms are rebuilt by mergeAndOrArgs, with `doNotRewrite` passed as its veto predicate;
//   * a negation is rebuilt with logic.mkNot only if its argument changed;
//   * any other Boolean operator is rebuilt with logic.insertTerm only if some argument changed.
// An unchanged term maps to itself, so the original PTRef is reused wherever possible.
//
// Preconditions (checked only by assert, i.e. not in builds that disable assertions):
//   * each term taken from the stack, including `root`, has a Boolean-operator symbol;
//   * each child of such a term is either a Boolean operator or an atom.
//
// doNotRewrite - callable taking a PTRef; forwarded unchanged to mergeAndOrArgs.
template<typename T>
PTRef rewriteMaxArity(Logic & logic, const PTRef root, T doNotRewrite) {
    Map<PTRef,PTRef,PTRefHash> cache;
    vec<PTRef> pending;
    pending.push(root);

    while (pending.size() > 0) {
        const PTRef current = pending.last();
        if (cache.has(current)) {
            // Already rewritten (the same term may be pushed more than once).
            pending.pop();
            continue;
        }

        const Pterm& term = logic.getPterm(current);
        bool childrenReady = true;
        for (PTRef child : term) {
            if (cache.has(child)) { continue; }
            if (logic.isBooleanOperator(child)) {
                // Must be rewritten first; `current` stays on the stack and is revisited later.
                pending.push(child);
                childrenReady = false;
            }
            else if (logic.isAtom(child)) {
                assert(!cache.has(child));
                cache.insert(child, child);
            } else {
                assert(false);
            }
        }
        if (!childrenReady) {
            continue;
        }

        pending.pop();
        const SymRef symbol = term.symb();
        assert(logic.isBooleanOperator(symbol));

        // In every branch below, `term` is read only before a new term may be created.
        PTRef rewritten = PTRef_Undef;
        if (logic.isAnd(symbol) or logic.isOr(symbol)) {
            rewritten = ::mergeAndOrArgs(logic, current, cache, doNotRewrite);
        } else if (logic.isNot(symbol)) {
            const PTRef operand = term[0];
            const PTRef newOperand = cache[operand];
            if (operand == newOperand) {
                rewritten = current;
            } else {
                rewritten = logic.mkNot(newOperand);
            }
        } else { // general connective
            vec<PTRef> rebuiltArgs;
            rebuiltArgs.capacity(term.size());
            bool modified = false;
            for (int i = 0; i < term.size(); ++i) {
                const PTRef oldArg = term[i];
                const PTRef newArg = cache[oldArg];
                if (newArg != oldArg) {
                    modified = true;
                }
                rebuiltArgs.push(newArg);
            }
            rewritten = modified ? logic.insertTerm(symbol, std::move(rebuiltArgs)) : current;
        }
        assert(rewritten != PTRef_Undef);
        assert(!cache.has(current));
        cache.insert(current, rewritten);
    }
    return cache[root];
}

#endif //OPENSMT_BOOLREWRITING_H
