#include "task_proxy.h"

#include "task_utils/causal_graph.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>

using namespace std;

/*
  Sentinel stored in a value vector for a variable that has no value
  (a "don't care"); such entries are skipped by the mutex checks and filled
  in when a partial assignment is completed to a full state.
*/
const int PartialAssignment::UNASSIGNED{-1};


State::State(const AbstractTask &task, std::vector<int> &&values)
        : PartialAssignment(task, std::move(values)) {
    assert(all_of(values.begin(), values.end(), [] (const int value){
        return value != PartialAssignment::UNASSIGNED;}));
}

/*
  Return the causal graph that causal_graph::get_causal_graph provides for
  the task wrapped by this proxy.
*/
const causal_graph::CausalGraph &TaskProxy::get_causal_graph() const {
    const causal_graph::CausalGraph &graph =
        causal_graph::get_causal_graph(task);
    return graph;
}

TaskProxy::TaskProxy(const AbstractTask &task)
    : task(&task) {}

/*
  Return true iff the fact (var, values[var]) is mutex, according to
  task->are_facts_mutex, with some fact (var2, values[var2]) where
  var2 >= first_var2_index and var2 != var.

  Entries equal to PartialAssignment::UNASSIGNED take part on neither side:
  if values[var] is unassigned the result is false, and unassigned partners
  are skipped. Passing first_var2_index = var + 1 lets a caller that loops
  over all variables inspect each unordered pair only once.
*/
bool contains_mutex_with_variable(
        const AbstractTask *task, size_t var, const vector<int> &values,
        size_t first_var2_index = 0) {
    assert(utils::in_bounds(var, values));
    if (values[var] == PartialAssignment::UNASSIGNED) {
        return false;
    }
    FactPair fp(var, values[var]);

    for (size_t var2 = first_var2_index; var2 < values.size(); ++var2) {
        // A don't-care partner has no fact that could be mutex with fp;
        // building FactPair(var2, UNASSIGNED) would hand an invalid fact to
        // are_facts_mutex.
        if (var2 == var || values[var2] == PartialAssignment::UNASSIGNED) {
            continue;
        }
        FactPair fp2(var2, values[var2]);
        if (task->are_facts_mutex(fp, fp2)) {
            return true;
        }
    }
    return false;
}

/*
  Return true iff contains_mutex_with_variable reports a mutex for some
  variable when only variables with a larger index are considered as
  partners, so each unordered pair of variables is inspected at most once.
  Stops at the first mutex found.
*/
bool contains_mutex(const AbstractTask *task, const vector<int> &values) {
    const size_t num_vars = values.size();
    size_t var = 0;
    while (var < num_vars) {
        const size_t first_partner = var + 1;
        if (contains_mutex_with_variable(task, var, values, first_partner)) {
            return true;
        }
        ++var;
    }
    return false;
}

/*
  Replace values[var] with non-mutex value. Return true iff such a
  non-mutex value could be found.
 */
/*
  Replace values[idx_var] with a value from the variable's domain for which
  contains_mutex_with_variable reports no mutex against the other entries.
  Candidate values are tried in an order shuffled with rng, and the first
  acceptable one is kept.

  Returns true iff such a value was found. Otherwise values[idx_var] is
  restored to its previous content and false is returned.
*/
static bool replace_with_non_mutex_value(
        const AbstractTask *task, vector<int> &values,
        const int idx_var, utils::RandomNumberGenerator &rng) {
    // Debug-build bounds check; the bare call's result would be ignored.
    assert(utils::in_bounds(idx_var, values));
    const int old_value = values[idx_var];
    vector<int> domain(task->get_variable_domain_size(idx_var));
    iota(domain.begin(), domain.end(), 0);
    rng.shuffle(domain);
    for (int new_value : domain) {
        values[idx_var] = new_value;
        if (!contains_mutex_with_variable(task, idx_var, values)) {
            return true;
        }
    }
    values[idx_var] = old_value;
    return false;
}


static const int MAX_TRIES_EXTEND = 10000;
static bool replace_dont_cares_with_non_mutex_values(
        const AbstractTask *task, vector<int> &values,
        utils::RandomNumberGenerator &rng) {
    assert(values.size() == (size_t) task->get_num_variables());
    vector<int> vars_order(task->get_num_variables());
    iota(vars_order.begin(), vars_order.end(), 0);

    for (int round = 0; round < MAX_TRIES_EXTEND; ++round) {
        bool invalid = false;
        rng.shuffle(vars_order);
        vector<int> new_values = values;

        for (int idx_var : vars_order) {
            if (new_values[idx_var] == PartialAssignment::UNASSIGNED) {
                if (!replace_with_non_mutex_value(
                        task, new_values, idx_var, rng)) {
                    invalid = true;
                    break;
                }
            }
        }
        if (!invalid) {
            values = new_values;
            return true;
        }
    }
    return false;
}

/*
  Return whether contains_mutex finds a pair of mutex facts among the
  values of this assignment.
*/
bool PartialAssignment::violates_mutexes() const {
    const bool has_mutex = contains_mutex(task, values);
    return has_mutex;
}
/*
  Complete this partial assignment into a State in which every variable has
  a value.

  If check_mutexes is true, don't-cares are filled with
  replace_dont_cares_with_non_mutex_values. This happens only if
  contains_mutex finds no mutex among the already-set values. The returned
  bool reports whether both steps succeeded. If check_mutexes is false, the
  returned bool is always true.

  Every entry that is still UNASSIGNED afterwards gets a value drawn with
  rng(domain_size). That covers every don't-care when check_mutexes is false,
  and the leftovers of a failed completion otherwise. As a result the State
  handed out never contains UNASSIGNED entries (assuming rng(n) yields values
  in [0, n)), which is what the State constructor requires. After a failed
  completion the State is only a placeholder and may violate mutexes;
  callers should look at the bool.
*/
pair<bool, State> PartialAssignment::get_full_state(
        bool check_mutexes,
        utils::RandomNumberGenerator &rng) const {
    vector<int> new_values = values;
    bool success = true;
    if (check_mutexes) {
        // If the fixed values are already mutex, no completion can succeed,
        // so the (up to MAX_TRIES_EXTEND attempts) search is skipped.
        success = !contains_mutex(task, new_values) &&
            replace_dont_cares_with_non_mutex_values(task, new_values, rng);
    }

    for (VariableProxy var : VariablesProxy(*task)) {
        int &value = new_values[var.get_id()];
        if (value == PartialAssignment::UNASSIGNED) {
            value = rng(var.get_domain_size());
        }
    }
    return make_pair(success, State(*task, move(new_values)));
}


/*
  Complete assignment into a full State by delegating to the const member
  PartialAssignment::get_full_state with the same check_mutexes flag and
  rng. assignment itself is only read.
*/
pair<bool, State> TaskProxy::convert_to_full_state(
        PartialAssignment &assignment,
        bool check_mutexes, utils::RandomNumberGenerator &rng) const {
    const PartialAssignment &partial = assignment;
    return partial.get_full_state(check_mutexes, rng);
}