#include "factored_transition_system.h"

#include "distances.h"
#include "labels.h"
#include "merge_and_shrink_representation.h"
#include "transition_system.h"
#include "utils.h"

#include "../utils/collections.h"
#include "../utils/memory.h"
#include "../utils/system.h"
#include "../utils/timer.h"

#include <cassert>

using namespace std;

namespace merge_and_shrink {
/*
  Create an iterator over the active factors of fts. A begin iterator starts
  at slot 0, an end iterator at fts.get_size(); in both cases the position is
  then moved forward to the first active slot (or the end).
*/
FTSConstIterator::FTSConstIterator(
    const FactoredTransitionSystem &fts,
    bool end)
    : fts(fts),
      current_index(end ? fts.get_size() : 0) {
    next_valid_index();
}

// Advance current_index past inactive slots, stopping at the first active
// slot or at fts.get_size().
void FTSConstIterator::next_valid_index() {
    for (; current_index < fts.get_size(); ++current_index) {
        if (fts.is_active(current_index)) {
            break;
        }
    }
}

// Step to the next active factor (or to the end position).
void FTSConstIterator::operator++() {
    current_index += 1;
    next_valid_index();
}


/*
  Take ownership of the labels and of the per-factor transition systems,
  representations and distances. Every factor starts out active. If init
  and/or goal distances are requested, they are computed for every factor
  here, so that each component satisfies is_component_valid afterwards.
  The final timing line is printed regardless of verbosity.
*/
FactoredTransitionSystem::FactoredTransitionSystem(
    unique_ptr<Labels> labels,
    vector<unique_ptr<TransitionSystem>> &&transition_systems,
    vector<unique_ptr<MergeAndShrinkRepresentation>> &&mas_representations,
    vector<unique_ptr<Distances>> &&distances,
    const bool compute_init_distances,
    const bool compute_goal_distances,
    Verbosity verbosity,
    const utils::Timer &timer)
    : labels(move(labels)),
      transition_systems(move(transition_systems)),
      mas_representations(move(mas_representations)),
      distances(move(distances)),
      compute_init_distances(compute_init_distances),
      compute_goal_distances(compute_goal_distances),
      num_active_entries(this->transition_systems.size()) {
    const bool compute_any_distances =
        compute_init_distances || compute_goal_distances;
    const size_t num_factors = this->transition_systems.size();
    for (size_t index = 0; index < num_factors; ++index) {
        if (compute_any_distances) {
            this->distances[index]->compute_distances(
                compute_init_distances, compute_goal_distances, verbosity);
        }
        assert(is_component_valid(index));
    }
    cout << "done creating FTS " << timer() << endl;
}

/*
  Member-wise move constructor: the owning containers are moved out of
  other, the scalar flags and counter are copied. It is written out by hand
  because, per the original note, Visual Studio did not support "= default"
  for move construction or move assignment at the time this was written.
*/
FactoredTransitionSystem::FactoredTransitionSystem(FactoredTransitionSystem &&other)
    : labels(move(other.labels)),
      transition_systems(move(other.transition_systems)),
      mas_representations(move(other.mas_representations)),
      distances(move(other.distances)),
      compute_init_distances(other.compute_init_distances),
      compute_goal_distances(other.compute_goal_distances),
      num_active_entries(other.num_active_entries) {
}

// Out-of-line destructor; all cleanup is done by the owning members.
// NOTE(review): presumably kept out of line so the unique_ptr pointee types
// only need to be complete in this translation unit — confirm in the header.
FactoredTransitionSystem::~FactoredTransitionSystem() = default;

/*
  Shrink the factor at index according to state_equivalence_relation.

  Returns false and leaves the factor untouched when the relation has as many
  entries as the factor has states. Otherwise the transition system, the
  distances (only if init or goal distances are maintained) and the
  representation's lookup table are all updated with the same abstraction
  mapping, and true is returned.
*/
bool FactoredTransitionSystem::apply_abstraction(
    int index,
    const StateEquivalenceRelation &state_equivalence_relation,
    Verbosity verbosity) {
    assert(is_component_valid(index));

    auto &ts = *transition_systems[index];
    const int old_num_states = ts.get_size();
    const int new_num_states = state_equivalence_relation.size();
    if (new_num_states == old_num_states) {
        if (verbosity >= Verbosity::VERBOSE) {
            cout << ts.tag()
                 << "not applying abstraction (same number of states)" << endl;
        }
        return false;
    }

    vector<int> abstraction_mapping = compute_abstraction_mapping(
        old_num_states, state_equivalence_relation);

    ts.apply_abstraction(
        state_equivalence_relation, abstraction_mapping, verbosity);

    const bool maintain_distances =
        compute_init_distances || compute_goal_distances;
    if (maintain_distances) {
        // Any recomputation of distances happens inside Distances itself.
        distances[index]->apply_abstraction(
            state_equivalence_relation,
            compute_init_distances,
            compute_goal_distances,
            verbosity);
    }

    mas_representations[index]->apply_abstraction_to_lookup_table(
        abstraction_mapping);

    assert(is_component_valid(index));
    return true;
}

/*
  Check that index is in range (debug builds only) and that the three
  per-factor components are consistent: either all present (active factor)
  or all absent (inactive slot). An inconsistent slot is a fatal error in
  every build type, and the offending index is reported.
*/
void FactoredTransitionSystem::assert_index_valid(int index) const {
    assert(utils::in_bounds(index, transition_systems));
    assert(utils::in_bounds(index, mas_representations));
    assert(utils::in_bounds(index, distances));
    const bool has_ts = transition_systems[index] != nullptr;
    const bool has_mas = mas_representations[index] != nullptr;
    const bool has_distances = distances[index] != nullptr;
    if (has_mas != has_ts || has_distances != has_ts) {
        cerr << "Factor at index " << index
             << " is in an inconsistent state!" << endl;
        utils::exit_with(utils::ExitCode::CRITICAL_ERROR);
    }
}

/*
  An active factor is valid when every kind of distance this FTS maintains
  (init and/or goal) has been computed and its transitions are sorted and
  unique. The checks run in that order and stop at the first failure.
*/
bool FactoredTransitionSystem::is_component_valid(int index) const {
    assert(is_active(index));
    return (!compute_init_distances
            || distances[index]->are_init_distances_computed())
           && (!compute_goal_distances
               || distances[index]->are_goal_distances_computed())
           && transition_systems[index]->are_transitions_sorted_unique();
}

// Debug-only check that every active factor satisfies is_component_valid;
// inactive slots are skipped.
void FactoredTransitionSystem::assert_all_components_valid() const {
    const size_t num_slots = transition_systems.size();
    for (size_t index = 0; index < num_slots; ++index) {
        if (!transition_systems[index]) {
            continue;
        }
        assert(is_component_valid(index));
    }
}

void FactoredTransitionSystem::apply_label_mapping(
    const vector<pair<int, vector<int>>> &label_mapping,
    int combinable_index) {
    assert_all_components_valid();
    for (const auto &new_label_old_labels : label_mapping) {
        assert(new_label_old_labels.first == labels->get_size());
        labels->reduce_labels(new_label_old_labels.second);
    }
    for (size_t i = 0; i < transition_systems.size(); ++i) {
        if (transition_systems[i]) {
            transition_systems[i]->apply_label_reduction(
                label_mapping, static_cast<int>(i) != combinable_index);
        }
    }
    assert_all_components_valid();
}

int FactoredTransitionSystem::merge(
    int index1,
    int index2,
    Verbosity verbosity,
    bool invalidating_merge) {
    assert(is_component_valid(index1));
    assert(is_component_valid(index2));
    transition_systems.push_back(
        TransitionSystem::merge(
            *labels,
            *transition_systems[index1],
            *transition_systems[index2],
            verbosity));
    if (invalidating_merge) {
        distances[index1] = nullptr;
        distances[index2] = nullptr;
        transition_systems[index1] = nullptr;
        transition_systems[index2] = nullptr;
        mas_representations.push_back(
            utils::make_unique_ptr<MergeAndShrinkRepresentationMerge>(
                move(mas_representations[index1]),
                move(mas_representations[index2])));
        mas_representations[index1] = nullptr;
        mas_representations[index2] = nullptr;
    } else {
        unique_ptr<MergeAndShrinkRepresentation> hr1 = nullptr;
        if (dynamic_cast<MergeAndShrinkRepresentationLeaf *>(mas_representations[index1].get())) {
            hr1 = utils::make_unique_ptr<MergeAndShrinkRepresentationLeaf>(
                dynamic_cast<MergeAndShrinkRepresentationLeaf *>
                    (mas_representations[index1].get()));
        } else {
            hr1 = utils::make_unique_ptr<MergeAndShrinkRepresentationMerge>(
                dynamic_cast<MergeAndShrinkRepresentationMerge *>(
                    mas_representations[index1].get()));
        }
        unique_ptr<MergeAndShrinkRepresentation> hr2 = nullptr;
        if (dynamic_cast<MergeAndShrinkRepresentationLeaf *>(mas_representations[index2].get())) {
            hr2 = utils::make_unique_ptr<MergeAndShrinkRepresentationLeaf>(
                        dynamic_cast<MergeAndShrinkRepresentationLeaf *>
                        (mas_representations[index2].get()));
        } else {
            hr2 = utils::make_unique_ptr<MergeAndShrinkRepresentationMerge>(
                        dynamic_cast<MergeAndShrinkRepresentationMerge *>(
                            mas_representations[index2].get()));
        }
        mas_representations.push_back(
            utils::make_unique_ptr<MergeAndShrinkRepresentationMerge>(
                move(hr1),
                move(hr2)));
    }
    const TransitionSystem &new_ts = *transition_systems.back();
    distances.push_back(utils::make_unique_ptr<Distances>(new_ts));
    int new_index = transition_systems.size() - 1;
    // Restore the invariant that distances are computed.
    if (compute_init_distances || compute_goal_distances) {
        distances[new_index]->compute_distances(
            compute_init_distances, compute_goal_distances, verbosity);
    }
    --num_active_entries;
    assert(is_component_valid(new_index));
    return new_index;
}

/*
  Move the representation and the distances of the factor at index out of
  the FTS and return them.
  NOTE: transition_systems[index] is left in place, so the slot is
  inconsistent afterwards. assert_index_valid would treat it as a fatal
  error, so the slot must not be accessed again.
*/
pair<unique_ptr<MergeAndShrinkRepresentation>, unique_ptr<Distances>>
FactoredTransitionSystem::extract_factor(int index) {
    assert(is_component_valid(index));
    unique_ptr<MergeAndShrinkRepresentation> representation =
        move(mas_representations[index]);
    unique_ptr<Distances> factor_distances = move(distances[index]);
    return make_pair(move(representation), move(factor_distances));
}

void FactoredTransitionSystem::statistics(int index) const {
    assert(is_component_valid(index));
    const TransitionSystem &ts = *transition_systems[index];
    ts.statistics();
    const Distances &dist = *distances[index];
    dist.statistics();
}

void FactoredTransitionSystem::dump(int index) const {
    assert_index_valid(index);
    transition_systems[index]->dump_labels_and_transitions();
    mas_representations[index]->dump();
}

// Ask the factor's transition system whether it is solvable, given the
// factor's own distances.
bool FactoredTransitionSystem::is_factor_solvable(int index) const {
    assert(is_component_valid(index));
    const auto &ts = *transition_systems[index];
    const auto &factor_distances = *distances[index];
    return ts.is_solvable(factor_distances);
}

// A slot is active while it still owns a transition system. The consistency
// check in assert_index_valid runs first and also runs in release builds.
bool FactoredTransitionSystem::is_active(int index) const {
    assert_index_valid(index);
    return static_cast<bool>(transition_systems[index]);
}

int FactoredTransitionSystem::get_init_state_goal_distance(int index) const {
    return distances[index]->get_goal_distance(transition_systems[index]->get_init_state());
}

void FactoredTransitionSystem::remove(int index) {
    assert(is_active(index));
    transition_systems[index] = nullptr;
    mas_representations[index] = nullptr;
    distances[index] = nullptr;
}

// Dump labels and transitions of every active factor, in slot order.
void FactoredTransitionSystem::dump() const {
    for (const auto &ts : transition_systems) {
        if (!ts) {
            continue;
        }
        ts->dump_labels_and_transitions();
    }
}
}
