#include "abstract_state.h"

#include "refinement_hierarchy.h"
#include "utils.h"

#include "../utils/memory.h"

#include <algorithm>
#include <cassert>
#include <unordered_set>

using namespace std;

namespace cegar {
/*
  Create an abstract state with ID *state_id* and node ID *node_id*. The
  Cartesian set describing the state is moved in, so the caller's object is
  left in a moved-from state.
*/
AbstractState::AbstractState(
    int state_id, NodeID node_id, CartesianSet &&cartesian_set)
    : state_id(state_id),
      node_id(node_id),
      // Qualify std::move explicitly: clang warns about unqualified move()
      // (-Wunqualified-std-cast-call). The argument is the parameter, which
      // shadows the member of the same name.
      cartesian_set(std::move(cartesian_set)) {
}

// Number of values that variable *var* may still take in this abstract state.
int AbstractState::count(int var) const {
    const int num_values = cartesian_set.count(var);
    return num_values;
}

// Whether the fact var=value is admitted by this abstract state's Cartesian set.
bool AbstractState::contains(int var, int value) const {
    const bool admitted = cartesian_set.test(var, value);
    return admitted;
}

/*
  Split the domain of *var* in this abstract state and return two Cartesian
  sets (first, second):
  - first: a copy of this state's Cartesian set from which all values in
    *wanted* have been removed for *var*;
  - second: a copy of this state's Cartesian set in which *var* is restricted
    to exactly the values in *wanted*.
  All other variables are unchanged in both results.

  Preconditions (checked by assertions only): *wanted* is non-empty, its
  values are distinct and currently possible for *var*, and at least one
  possible value of *var* is not wanted.
*/
pair<CartesianSet, CartesianSet> AbstractState::split_domain(
    int var, const vector<int> &wanted) const {
    int num_wanted = static_cast<int>(wanted.size());
    utils::unused_variable(num_wanted);
    // We can only refine for variables with at least two values.
    assert(num_wanted >= 1);
    assert(cartesian_set.count(var) > num_wanted);

    CartesianSet v1_cartesian_set(cartesian_set);
    CartesianSet v2_cartesian_set(cartesian_set);

    v2_cartesian_set.remove_all(var);
    for (int value : wanted) {
        // The wanted value has to be in the set of possible values.
        assert(cartesian_set.test(var, value));

        // In v1 var can have all of the previous values except the wanted ones.
        v1_cartesian_set.remove(var, value);

        // In v2 var can only have the wanted values.
        v2_cartesian_set.add(var, value);
    }
    assert(v1_cartesian_set.count(var) == cartesian_set.count(var) - num_wanted);
    assert(v2_cartesian_set.count(var) == num_wanted);
    // Move the locals into the result: make_pair on lvalues would copy both
    // Cartesian sets.
    return make_pair(std::move(v1_cartesian_set), std::move(v2_cartesian_set));
}

/*
  Regress this abstract state through operator *op*, whose effects are given
  as *factored_effect_pairs* (each pair <var, value_before, value_after>), and
  return the resulting Cartesian set. This abstract state is not modified.
*/
CartesianSet AbstractState::regress(
    const OperatorProxy &op,
    const vector<FactoredEffectPair> &factored_effect_pairs) const {
    CartesianSet regression = cartesian_set;

    /*
      As described in the paper (buechner-et-al-icaps2022wshsdip), there are two
      reasons why a value fact V=v can be true in a state after applying an
      operator o:
      (1) there is no effect in o that has effect condition V=v, or
      (2) there is an effect <V,u,v> in o.

      The first for-loop below takes care of (1) by removing from the
      regression all values of this state that occur as effect conditions.
      Only such values can be removed, so iterating over the effect pairs
      (instead of over every value of every variable) yields the same set
      without scanning all domains.

      The second for-loop below takes care of (2) by adding, for each effect
      that generates a value of this state, its condition u to the regression.

      All removals happen before any addition, so the order in which the
      effect pairs are visited does not matter.
    */
    for (const FactoredEffectPair &fep : factored_effect_pairs) {
        if (cartesian_set.test(fep.var, fep.value_before)) {
            regression.remove(fep.var, fep.value_before);
        }
    }
    for (const FactoredEffectPair &fep : factored_effect_pairs) {
        if (cartesian_set.test(fep.var, fep.value_after)) {
            regression.add(fep.var, fep.value_before);
        }
    }
    /*
      The third for-loop ensures that if o has a precondition on V, then only
      the value specified in the precondition is viable in the regression
      through o.
    */
    for (FactProxy precondition : op.get_preconditions()) {
        int var_id = precondition.get_variable().get_id();
        int val = precondition.get_value();
        if (regression.test(var_id, val)) {
            regression.set_single_value(var_id, val);
        } else {
            /*
              This case captures the possibility that operator *op* has a
              precondition on variable *var_id* but the value *val* required in
              the precondition is not part of the regression computed above.

              If this is the case, the regression should actually be empty
              because no state where *op* is applicable can achieve the abstract
              state over which we regress, which is why we remove all values
              for this variable.
            */
            regression.remove_all(var_id);
        }
    }
    return regression;
}

bool AbstractState::domain_subsets_intersect(const AbstractState &other, int var) const {
    return cartesian_set.intersects(other.cartesian_set, var);
}

/*
  A concrete state belongs to this abstract state iff every one of its facts
  is admitted by the Cartesian set. Returns false at the first fact that is
  not admitted.
*/
bool AbstractState::includes(const State &concrete_state) const {
    for (FactProxy fact : concrete_state) {
        const int var_id = fact.get_variable().get_id();
        const int value = fact.get_value();
        if (!cartesian_set.test(var_id, value)) {
            return false;
        }
    }
    return true;
}

bool AbstractState::includes(const vector<FactPair> &facts) const {
    for (const FactPair &fact : facts) {
        if (!cartesian_set.test(fact.var, fact.value))
            return false;
    }
    return true;
}

bool AbstractState::includes(const AbstractState &other) const {
    return cartesian_set.is_superset_of(other.cartesian_set);
}

int AbstractState::get_id() const {
    return state_id;
}

// The node ID passed to the constructor.
NodeID AbstractState::get_node_id() const {
    return this->node_id;
}

/*
  Build the abstract state with state ID 0 and node ID 0 whose Cartesian set
  is constructed from *domain_sizes*. Presumably that set admits every value
  of every variable; confirm against CartesianSet's constructor.
*/
unique_ptr<AbstractState> AbstractState::get_trivial_abstract_state(
    const vector<int> &domain_sizes) {
    return utils::make_unique_ptr<AbstractState>(
        /*state_id=*/0,
        /*node_id=*/0,
        CartesianSet(domain_sizes));
}

// Return, by value, the values that variable *var* may take in this abstract state.
std::vector<int> AbstractState::get_values(int var) const {
    std::vector<int> values = cartesian_set.get_values(var);
    return values;
}
}
