#include "assignment_cost_generator_factory.h"

#include "successor_generator_factory.h"
#include "successor_generator_internals.h"

#include "sampling.h"

#include "../utils/collections.h"
#include "../utils/memory.h"

#include <algorithm>
#include <cassert>

using namespace std;

/*
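  The key ideas of the construction algorithm are as follows. The
  algorithm is the successor generator construction applied to partial
  assignments: below, an "operator" stands for a partial assignment and
  its "preconditions" are the facts it assigns (kept sorted in each
  AssignmentInfo). Leaves store the minimum cost of their assignments
  instead of a list of applicable operators.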

  Initially, we sort the preconditions of the operators
  lexicographically.

  We then group the operators by the *first variable* to be tested,
  forming a group for each variable and possibly a special group for
  operators with no precondition. (This special group forms the base
  case of the recursion, leading to a leaf node in the successor
  generator.) Each group forms a contiguous subrange of the overall
  operator sequence.

  We then further group each subsequence (except for the special one)
  by the *value* that the given variable is tested against, again
  obtaining contiguous subranges of the original operator sequence.

  For each of these subranges, we "pop" the first condition and then
  recursively apply the same construction algorithm to compute a child
  successor generator for this subset of operators.

  Successor generators for different values of the same variable are
  then combined into a switch node, and the generated switch nodes for
  different variables are combined into a fork node.

  The important property of lexicographic sorting that we exploit here
  is that if the original sequence is sorted, then all subsequences we
  must consider recursively are also sorted. Crucially, this remains
  true when we pop the first condition, because this popping always
  happens within a subsequence where all operators have the *same*
  first condition.

  To make the implementation more efficient, we do not physically pop
  conditions but only keep track of how many conditions have been
  dealt with so far, which is simply the recursion depth of the
  "construct_recursive" function.

  Because we only consider contiguous subranges of the operator
  sequence and never need to modify any of the data describing the
  operators, we can simply keep track of the current operator sequence
  by a begin and end index into the overall operator sequence.
*/

namespace assignment_cost_generator {
    using namespace successor_generator;
    using namespace sampling;

// Returns the given facts in ascending FactPair order. The argument is
// received by value, so sorting it in place leaves the caller's vector as is.
vector<FactPair> sort_assignments(vector<FactPair> assignments) {
    auto first = assignments.begin();
    auto last = assignments.end();
    sort(first, last);
    return assignments;
}

/*
  Records an assignment's id and cost together with its assigned facts.
  The facts are stored sorted so that AssignmentInfo objects can be
  compared lexicographically by their facts (see operator<).
*/
AssignmentInfo::AssignmentInfo(
    size_t assignment_id, const PartialAssignment &assignment, int cost)
    : assignment_id(assignment_id),
      assignments(sort_assignments(assignment.get_assigned_facts())),
      cost(cost) {
}

// Lexicographic order on the sorted fact lists; id and cost play no role.
bool AssignmentInfo::operator<(const AssignmentInfo &other) const {
    const auto &lhs = assignments;
    const auto &rhs = other.assignments;
    return lexicographical_compare(
        lhs.begin(), lhs.end(), rhs.begin(), rhs.end());
}

size_t AssignmentInfo::get_id() const {
    return assignment_id;
}

int AssignmentInfo::get_cost() const {
    return cost;
}
// Variable of the depth-th sorted fact. When depth equals the number of
// facts, returns -1 as a past-the-end sentinel (no fact left to test).
int AssignmentInfo::get_var(int depth) const {
    const int num_facts = static_cast<int>(assignments.size());
    return (depth == num_facts) ? -1 : assignments[depth].var;
}

// Value of the depth-th sorted fact. There is no sentinel here, so depth
// must index an existing fact.
int AssignmentInfo::get_value(int depth) const {
    const auto &fact = assignments[depth];
    return fact.value;
}



// First assignment of the part of the range that has not been grouped yet.
const AssignmentInfo &AssignmentGrouper::get_current_assignment_info() const {
    assert(!range.empty());
    const auto current = range.begin;
    return assignment_infos[current];
}

int AssignmentGrouper::get_current_group_key() const {
    const AssignmentInfo &op_info = get_current_assignment_info();
    if (group_by == GroupAssignmentsBy::VAR) {
        return op_info.get_var(depth);
    } else {
        assert(group_by == GroupAssignmentsBy::VALUE);
        return op_info.get_value(depth);
    }
}
/*
  Splits `range` of assignment_infos into runs sharing the same variable
  (VAR) or the same value (VALUE) of the fact at position `depth`.
  NOTE(review): runs are only maximal groups if assignment_infos is
  sorted; callers in this file sort it before constructing groupers.
*/
AssignmentGrouper::AssignmentGrouper(
    const vector<AssignmentInfo> &assignment_infos, int depth,
    GroupAssignmentsBy group_by, OperatorRange range)
    : assignment_infos(assignment_infos),
      depth(depth),
      group_by(group_by),
      range(range) {
}

// True once every assignment of the range has been handed out in a group.
bool AssignmentGrouper::done() const {
    const bool nothing_left = range.empty();
    return nothing_left;
}

/*
  Consumes the maximal run of assignments at the front of the remaining
  range that share the current group key. Returns that key together with
  the consumed subrange. Must not be called when done() is true.
*/
pair<int, OperatorRange> AssignmentGrouper::next() {
    assert(!range.empty());
    const int key = get_current_group_key();
    const int first = range.begin;
    ++range.begin;
    while (!range.empty() && get_current_group_key() == key)
        ++range.begin;
    return make_pair(key, OperatorRange(first, range.begin));
}

// START FACTORY

// Keeps the task proxy; construct_switch uses it to look up variable
// domain sizes.
BaseAssignmentCostGeneratorFactory::BaseAssignmentCostGeneratorFactory(
    const TaskProxy &task_proxy)
    : task_proxy(task_proxy) {
}

BaseAssignmentCostGeneratorFactory::~BaseAssignmentCostGeneratorFactory() = default;

/*
  Combines sibling generators into one node:
  - a single node is returned unchanged;
  - two nodes use the binary fork;
  - any other count uses the multi-way fork.
  The "other count" case includes zero. That only happens at the root
  when there are no assignments at all.
*/
GeneratorPtr BaseAssignmentCostGeneratorFactory::construct_fork(
    vector<GeneratorPtr> nodes) const {
    switch (nodes.size()) {
    case 1:
        return move(nodes.at(0));
    case 2:
        return utils::make_unique_ptr<GeneratorForkBinary>(
            move(nodes.at(0)), move(nodes.at(1)));
    default:
        return utils::make_unique_ptr<GeneratorForkMulti>(move(nodes));
    }
}

/*
  Builds a leaf for a group of assignments that have no facts left to
  test at this depth. The leaf holds the minimum cost among the
  assignments in `range`. The cost is wrapped in an OperatorID so that it
  fits the successor generator's GeneratorLeafSingle.

  The minimum is seeded from the first assignment (the range is
  non-empty) rather than from a -1 "unset" marker. With the marker, a
  stored cost of -1 made the result depend on the order of the
  assignments.
*/
GeneratorPtr BaseAssignmentCostGeneratorFactory::construct_leaf(
    OperatorRange range) const {
    assert(!range.empty());
    int min_cost = assignment_infos[range.begin].get_cost();
    for (int i = range.begin + 1; i < range.end; ++i) {
        min_cost = min(min_cost, assignment_infos[i].get_cost());
    }
    return utils::make_unique_ptr<GeneratorLeafSingle>(OperatorID(min_cost));
}

/*
  Builds a switch node on variable switch_var_id from (value, child)
  pairs.
  - A single child uses GeneratorSwitchSingle.
  - Otherwise the children are stored either in a hash map or in a vector
    indexed by value, whichever is estimated to take fewer bytes.
  covers_all_values tells the node whether every value of the variable's
  domain has a child.
*/
GeneratorPtr BaseAssignmentCostGeneratorFactory::construct_switch(
    int switch_var_id, ValuesAndGenerators values_and_generators) const {
    VariablesProxy vars = task_proxy.get_variables();
    const int domain_size = vars[switch_var_id].get_domain_size();
    const int num_children = values_and_generators.size();
    const bool covers_all_values = (num_children == domain_size);
    assert(num_children > 0);

    if (num_children == 1) {
        auto &only_child = values_and_generators[0];
        return utils::make_unique_ptr<GeneratorSwitchSingle>(
            switch_var_id, only_child.first, move(only_child.second),
            covers_all_values);
    }

    const int vector_bytes =
        utils::estimate_vector_bytes<GeneratorPtr>(domain_size);
    const int hash_bytes =
        utils::estimate_unordered_map_bytes<int, GeneratorPtr>(num_children);
    if (vector_bytes <= hash_bytes) {
        vector<GeneratorPtr> children_by_value(domain_size);
        for (auto &child : values_and_generators)
            children_by_value[child.first] = move(child.second);
        return utils::make_unique_ptr<GeneratorSwitchVector>(
            switch_var_id, move(children_by_value), covers_all_values);
    }

    unordered_map<int, GeneratorPtr> children_by_value;
    for (auto &child : values_and_generators)
        children_by_value[child.first] = move(child.second);
    return utils::make_unique_ptr<GeneratorSwitchHash>(
        switch_var_id, move(children_by_value), covers_all_values);
}

/*
  Builds the generator for the assignments in `range`, whose first
  `depth` facts have already been handled by enclosing switch nodes.
  - Assignments are grouped by the variable of their depth-th fact.
  - The group with no remaining fact (variable -1) becomes a leaf.
  - Every other group becomes a switch on its variable. Each value's
    child is built recursively one level deeper.
  All groups at this level are combined by construct_fork.
*/
GeneratorPtr BaseAssignmentCostGeneratorFactory::construct_recursive(
    int depth, OperatorRange range) const {
    vector<GeneratorPtr> children;
    for (AssignmentGrouper by_var(
             assignment_infos, depth, GroupAssignmentsBy::VAR, range);
         !by_var.done();) {
        const auto var_group = by_var.next();
        const int var = var_group.first;
        if (var == -1) {
            // Assignments with no fact left to test at this depth.
            children.push_back(construct_leaf(var_group.second));
            continue;
        }

        // Assignments whose depth-th fact is on `var`: split by value.
        ValuesAndGenerators value_children;
        for (AssignmentGrouper by_value(
                 assignment_infos, depth, GroupAssignmentsBy::VALUE,
                 var_group.second);
             !by_value.done();) {
            const auto value_group = by_value.next();
            value_children.emplace_back(
                value_group.first,
                construct_recursive(depth + 1, value_group.second));
        }
        children.push_back(construct_switch(var, move(value_children)));
    }
    return construct_fork(move(children));
}

// Takes the registry used to resolve assignment ids and the id -> cost
// map that create() turns into a generator.
AssignmentCostGeneratorFactory::AssignmentCostGeneratorFactory(
    const TaskProxy &task_proxy,
    const PartialAssignmentRegistry &registry,
    const utils::HashMap<size_t, int> &id2costs)
    : BaseAssignmentCostGeneratorFactory(task_proxy),
      registry(registry),
      id2costs(id2costs) {
}

AssignmentCostGeneratorFactory::~AssignmentCostGeneratorFactory() = default;
/*
  Builds the generator for every (id, cost) entry of id2costs.
  - Each id is resolved to its partial assignment through the registry.
  - The resulting AssignmentInfo objects are sorted.
  - The recursive construction then runs over the full sorted range.
  assignment_infos is only scratch space and is emptied before returning.
*/
GeneratorPtr AssignmentCostGeneratorFactory::create() {
    assignment_infos.reserve(id2costs.size());
    for (const auto &id_and_cost : id2costs) {
        const size_t id = id_and_cost.first;
        assignment_infos.emplace_back(
            id, registry.lookup_by_id(id), id_and_cost.second);
    }

    /* id2costs is a hash map, so the insertion order above (its iteration
       order) is unspecified and cannot serve as a reproducible
       tie-breaker. Order by facts first and by assignment id second, so
       the sorted sequence does not depend on the hash map's layout. */
    sort(assignment_infos.begin(), assignment_infos.end(),
         [](const AssignmentInfo &lhs, const AssignmentInfo &rhs) {
             if (lhs < rhs)
                 return true;
             if (rhs < lhs)
                 return false;
             return lhs.get_id() < rhs.get_id();
         });

    OperatorRange full_range(0, assignment_infos.size());
    GeneratorPtr root = construct_recursive(0, full_range);
    assignment_infos.clear();
    return root;
}

/*
  Factory in which every assignment gets the same `cost`.
  The assignment list arrives by value, so it is moved into the member
  instead of being copied a second time. Top-level const on a by-value
  parameter is not part of the function signature, so omitting it here
  still matches the declaration.
  NOTE(review): this assumes the `assignments` member is a value, not a
  reference; a reference member would dangle after the constructor
  returns. Confirm in the header.
*/
AssignmentUnitCostGeneratorFactory::AssignmentUnitCostGeneratorFactory(
        const TaskProxy &task_proxy,
        std::vector<PartialAssignment> assignments, const int cost)
        : BaseAssignmentCostGeneratorFactory(task_proxy),
          assignments(move(assignments)),
          cost(cost) {}


AssignmentUnitCostGeneratorFactory::~AssignmentUnitCostGeneratorFactory() = default;
/*
  Builds the generator for the stored assignments, all of which share
  the same cost. These assignments have no registry ids, so each one
  gets the placeholder id static_cast<size_t>(-1). assignment_infos is
  only scratch space and is emptied before returning.
*/
GeneratorPtr AssignmentUnitCostGeneratorFactory::create() {
    const size_t placeholder_id = static_cast<size_t>(-1);
    assignment_infos.reserve(assignments.size());
    for (const auto &assignment : assignments)
        assignment_infos.emplace_back(placeholder_id, assignment, cost);

    /* AssignmentInfo::operator< compares facts only, so stable_sort keeps
       assignments with identical facts in their order in `assignments`,
       which keeps the result reproducible. */
    stable_sort(assignment_infos.begin(), assignment_infos.end());

    const OperatorRange everything(0, assignment_infos.size());
    GeneratorPtr root = construct_recursive(0, everything);
    assignment_infos.clear();
    return root;
}
}
