#include "merge_and_shrink_heuristic.h"

#include "abstraction.h"
#include "labels.h"
#include "merge_miasm.h"
#include "merge_strategy.h"
#include "merge_tree.h"
#include "shrink_strategy.h"

#include "../global_state.h"
#include "../globals.h"
#include "../legacy_causal_graph.h"
#include "../option_parser.h"
#include "../plugin.h"
#include "../utils/logging.h"
#include "../utils/system.h"
#include "../utils/timer.h"

#include "util/lattice_varset.h"
#include "util/set_packing.h"

#include "../algorithms/sccs.h"

#include "../task_utils/task_properties.h"
#include "../tasks/root_task.h"

#include <cassert>
#include <limits>
#include <vector>
using namespace std;
using namespace utils;

namespace mas_new_lr {
/**
 * Construct the heuristic from parsed options.
 *
 * For the "miasm" and "symmetries_miasm" merge strategies this first runs
 * the subset search (search_promising_clusters), builds a greedy set packing
 * over the collected cluster weights and, unless the packing is degenerate
 * (empty, or as many clusters as variables), installs the resulting merge
 * tree in the merge strategy. Afterwards initialize() builds the abstraction.
 *
 * An earlier revision cached clusters, packing and merge tree in files
 * listed in "SHR_filelist"; that code path was disabled and is not kept here.
 */
MergeAndShrinkHeuristic::MergeAndShrinkHeuristic(const Options &opts)
    : Heuristic(opts),
      merge_strategy(opts.get<MergeStrategy *>("merge_strategy")),
      shrink_strategy(opts.get<ShrinkStrategy *>("shrink_strategy")),
      use_expensive_statistics(opts.get<bool>("expensive_statistics")),
      cost_type(static_cast<OperatorCost>(opts.get_enum("cost_type"))),
      sum_computation_bound(opts.get<int>("sum_computation_bound")) {
    labels = new Labels(is_unit_cost(), opts, cost_type);
    if (merge_strategy->name() == "miasm" || merge_strategy->name() == "symmetries_miasm") {
        // HACK!
        tasks::g_need_mutex_groups = true;

        assert(sum_computation_bound != -1);

        reach_bound = opts.get<int>("miasm_max_states");

        // NOTE(review): the return value (false = unsolvability detected
        // during the search) is ignored, as in the original code.
        search_promising_clusters();

        set<set<int> > packing;
        cerr << "constructing a packing\n";
        const double total_weight = set_packing::greedy_packing(weight, packing);
        cerr << packing << endl << total_weight << endl;

        /* If each variable forms its own cluster (or no cluster was found),
         * merging along those clusters is not considered reasonable. */
        if (static_cast<int>(packing.size()) == tasks::g_root_task->get_num_variables() || packing.size() == 0) {
            cerr << "packing: " << packing << endl;
            cerr << "no non-atomic cluster found!" << endl;
            cout << "switching to DFP merge strategy" << endl;
        } else {
            // Cast once and fail loudly instead of dereferencing a null
            // pointer if the strategy is not a MergeMiasm.
            MergeMiasm *miasm_strategy = dynamic_cast<MergeMiasm *>(merge_strategy);
            if (!miasm_strategy) {
                cerr << "Merge strategy " << merge_strategy->name()
                     << " cannot store a MIASM merge tree" << endl;
                exit_with(ExitCode::CRITICAL_ERROR);
            }

            miasm_strategy->merge_tree =
                MergeTree::merge_tree_clusters(packing,
                    CLUSTER_EXTERNAL_LINEAR_SMALL_CG_GOAL_LEVEL,
                    CLUSTER_INTERNAL_LINEAR_REVERSE_LEVEL);
            miasm_strategy->merge_tree_set = true;

            // Clear every tree node that holds more than one entry.
            for (tree<node_t>::pre_order_iterator tree_it = miasm_strategy->merge_tree.begin();
                 tree_it != miasm_strategy->merge_tree.end(); ++tree_it) {
                node_t &node = *tree_it;
                if (node.size() > 1) {
                    node.clear();
                }
            }

            cout << "Done initializing MIASM merge tree" << endl;
        }
    }

    initialize();
}

// Frees the merge strategy, the shrink strategy and the label set held by
// this heuristic (the label set is allocated in the constructor).
// NOTE(review): final_abstraction is not deleted here — confirm ownership.
MergeAndShrinkHeuristic::~MergeAndShrinkHeuristic() {
    delete merge_strategy;
    delete shrink_strategy;
    delete labels;
}

// Prints how much the peak memory has grown since initialize() recorded
// starting_peak_memory. The line is prefixed with "Final" or "Current"
// depending on the flag.
void MergeAndShrinkHeuristic::report_peak_memory_delta(bool final) const {
    const char *prefix = final ? "Final" : "Current";
    cout << prefix
         << " peak memory increase of merge-and-shrink computation: "
         << utils::get_peak_memory_in_kb() - starting_peak_memory << " KB"
         << endl;
}

// Dumps the options of the merge strategy, the shrink strategy and the
// labels, followed by the state of the expensive-statistics switch.
void MergeAndShrinkHeuristic::dump_options() const {
    merge_strategy->dump_options();
    shrink_strategy->dump_options();
    labels->dump_options();

    const char *statistics_state =
        use_expensive_statistics ? "enabled" : "disabled";
    cout << "Expensive statistics: " << statistics_state << endl;
}

void MergeAndShrinkHeuristic::warn_on_unusual_options() const {
    if (use_expensive_statistics) {
        string dashes(79, '=');
        cerr << dashes << endl
             << ("WARNING! You have enabled extra statistics for "
            "merge-and-shrink heuristics.\n"
            "These statistics require a lot of time and memory.\n"
            "When last tested (around revision 3011), enabling the "
            "extra statistics\nincreased heuristic generation time by "
            "76%. This figure may be significantly\nworse with more "
            "recent code or for particular domains and instances.\n"
            "You have been warned. Don't use this for benchmarking!")
        << endl << dashes << endl;
    }
}

/**
 * Construct a var_set object.
 * @param var_set: the variable set
 * @param lv: the ref to the constructed variable set
 * @return true if all value associated with the var_set object can be computed
 */
bool MergeAndShrinkHeuristic::make_lattice_varset(const set<int>& var_set,
        LatticeVarset& lv) {
    lv.varset = var_set;
    /* Assume variables in any cluster are merged in REVERSE_LEVEL
     * TODO specify other merging order */
    vector<int> varorder(lv.varset.begin(), lv.varset.end());

    if (ratio_table.count(lv.varset)) {
        lv.ratio = ratio_table[lv.varset];
        if (lv.ratio != -1) {
            lv.total = total_count(varorder);
            lv.reach = lv.ratio * lv.total + 0.5;
            lv.diff = 1 - lv.ratio;
            if (!compute_diff_for(lv)) return false;
        } else {
            if (diff_table.count(lv.varset)) {
                assert(diff_table[lv.varset] == -1);
            }
        }
    } else {
        if (!reach_count(varorder, lv.reach)) return false; /* size too big */
        if (lv.reach != -1) {
            lv.total = total_count(varorder);
            lv.ratio = (double) lv.reach / lv.total;
            lv.diff = 1 - lv.ratio;
            if (!compute_diff_for(lv)) return false;
        } else {
            assert(lv.diff == -1 && lv.ratio == -1 && lv.total == -1);
        }
    }


    update_tables(lv.varset, lv.ratio, lv.diff);

    return true;
}

/**
 * Record a variable set as enqueued and push it onto the priority queue.
 * The set's total state count is added to sum_total_enqueued.
 * @param pq: the priority queue receiving the entry
 * @param lv: the lattice entry to enqueue; its reach must not exceed
 *            reach_bound (checked by assertion only)
 */
void MergeAndShrinkHeuristic::enqueue(priority_queue<LatticeVarset>& pq,
        LatticeVarset& lv) {
    // Remember the set so that it is never enqueued a second time.
    enqueued.insert(lv.varset);

    assert(lv.reach <= reach_bound);

    pq.push(lv);
    sum_total_enqueued += lv.total;
}


/**
 * Compute $R_d$ (see the paper) for the variable set lv.
 *
 * A cached value from diff_table is used when present. Otherwise every
 * split of the set into a k-subset and its complement (k up to half the set
 * size) is examined while the computation budget lasts, and lv.diff becomes
 * the smallest value of r(part 1) * r(part 2) - lv.ratio that is below the
 * incoming lv.diff. lv.diff is set to -1 when a part has an invalid ratio
 * or the budget is exhausted.
 *
 * @param lv: the lattice entry; ratio and diff must already be set
 * @return false if reach_count() returned false for a part; true otherwise
 */
bool MergeAndShrinkHeuristic::compute_diff_for(LatticeVarset& lv) {
    if (diff_table.count(lv.varset)) {
        lv.diff = diff_table[lv.varset];
        return true;
    }

    assert(abs(1 - lv.ratio - lv.diff) < 1e-7);

    const int max_k = static_cast<int>(lv.varset.size()) / 2;
    for (int k = 1; k <= max_k; ++k) {
        if (sum_computation > sum_computation_bound)
            break;

        set<set<int> > k_subsets;
        lv.enumerate_k_subsets(k, k_subsets);

        for (const set<int> &first_part : k_subsets) {
            if (sum_computation > sum_computation_bound)
                break;

            // parts[0] is the k-subset, parts[1] its complement in lv.varset.
            set<int> parts[2];
            parts[0] = first_part;
            set_difference(lv.varset.begin(), lv.varset.end(),
                    parts[0].begin(), parts[0].end(),
                    inserter(parts[1], parts[1].end()));

            double part_ratio[2];
            for (int side = 0; side < 2; ++side) {
                const set<int> &part = parts[side];
                part_ratio[side] = -1;
                if (!ratio_table.count(part)) {
                    const vector<int> varorder(part.begin(), part.end());
                    int part_reach;
                    if (!reach_count(varorder, part_reach))
                        return false;
                    // -1 marks a part whose reachable count is unavailable.
                    const double new_ratio = (part_reach == -1)
                        ? -1
                        : static_cast<double>(part_reach) / total_count(varorder);
                    ratio_table.insert(make_pair(part, new_ratio));
                }
                part_ratio[side] = ratio_table[part];
                if (part_ratio[side] == -1) {
                    lv.diff = -1;
                    return true;
                }
            }

            assert(lv.ratio != -1 && part_ratio[0] != -1 && part_ratio[1] != -1);
            const double candidate = part_ratio[0] * part_ratio[1] - lv.ratio;
            if (candidate < lv.diff) {
                lv.diff = candidate;
                // Values this small are treated as zero; stop immediately.
                if (lv.diff < 1e-7) {
                    lv.diff = 0;
                    return true;
                }
            }
        }
    }

    if (sum_computation > sum_computation_bound) {
        lv.diff = -1;
    }
    return true;
}

/**
 * check all subsets of starting variable sets
 * @param var_set: the variable set
 * @param enqueued: the set of enqueued variable set
 * @param promising_subsets: the promising subsets
 * @return true always
 */
bool MergeAndShrinkHeuristic::subset_check(const set<int> & var_set,
        const set<set<int> >& enqueued,
        vector<LatticeVarset>& promising_subsets) {

    //    map<set<int>, double> local_min_diff_table;

    for (int k = 2; k < static_cast<int>(var_set.size())
            && sum_computation <= sum_computation_bound; k++) {
        set<set<int> > k_subsets;
        LatticeVarset::enumerate_k_subsets(var_set, k, k_subsets);
        int blowup = 0;
        for (set<set<int> >::iterator i_S = k_subsets.begin();
                i_S != k_subsets.end()
                && sum_computation <= sum_computation_bound; i_S++) {
            if (enqueued.count(*i_S)) continue;
            LatticeVarset lv_S;
            if (!(make_lattice_varset(*i_S, lv_S))) return false;
            if (lv_S.diff != -1) {
                if (lv_S.diff > 0) {
                    promising_subsets.push_back(lv_S);
                }
            } else {
                blowup++;
            }
        }
        if (blowup == static_cast<int>(k_subsets.size())) {

            break;
        }
    }
    return true;
}

// Comparison function used to sort the SCCs in kickstart(): orders sets by
// their smallest element. Empty sets sort before all non-empty sets, so
// begin() is never dereferenced on an empty set and the relation stays a
// strict weak ordering.
bool compare_int_sets(const std::set<int> &lhs,
                 const std::set<int> &rhs) {
    if (rhs.empty())
        return false;
    if (lhs.empty())
        return true;
    return *lhs.begin() < *rhs.begin();
}

/**
 * enqueue the starting subsets including sets of mutex groups, strongly
 * connected components
 * @return false if unsolvability detected; true otherwise
 */
bool MergeAndShrinkHeuristic::kickstart() {
    // Compute the strongly connected components of the causal graph
    /*
      NOTE: this implementation (cf. old repo ms-miasm in archive
      sievers-et-al-aaai2015.tar.bz2) indeed uses the legacy CG here but the
      normal one in merge_tree.cc
    */
    legacy_causal_graph::LegacyCausalGraph legacy_causal_graph(legacy_causal_graph::get_causal_graph(task.get()));
    vector<vector<int> > cg;
    cg.reserve(tasks::g_root_task->get_num_variables());
    for (int var = 0; var < tasks::g_root_task->get_num_variables(); ++var) {
        const vector<int> &succs = legacy_causal_graph.get_successors(var);
        cg.push_back(succs);
    }
    vector<vector<int>> scc_result(sccs::compute_maximal_sccs(cg));
    cout << "scc result: " << scc_result << endl;
    vector<set<int> > sccs;
    sccs.reserve(scc_result.size());
    for (size_t i = 0; i < scc_result.size(); ++i) {
        set<int> scc;
        scc.insert(scc_result[i].begin(), scc_result[i].end());
        sccs.push_back(scc);
    }
    sort(sccs.begin(), sccs.end(), compare_int_sets);

    vector<set<int> > kickstart_sets;
    map<int, vector<set<int> > > kickstart_set_map;

    for (int i = 0; i < static_cast<int>(tasks::g_mutex_groups.size()); i++) {
        kickstart_sets.push_back(tasks::g_mutex_groups[i]);
    }

    for (int i = 0; i < static_cast<int>(sccs.size()); i++) {
        kickstart_sets.push_back(sccs[i]);
    }

    for (int i = 0; i < static_cast<int>(kickstart_sets.size()); i++) {
        //        kickstart_sets.push_back(mutex_group[i]);
        if (!kickstart_set_map.count(kickstart_sets[i].size())) {
            kickstart_set_map.insert(
                    make_pair(kickstart_sets[i].size(), vector<set<int> >()));
        }
        kickstart_set_map[kickstart_sets[i].size()].push_back(kickstart_sets[i]);
    }

    kickstart_sets.clear();
    for (map<int, vector<set<int> > >::iterator i_m = kickstart_set_map.begin();
            i_m != kickstart_set_map.end(); i_m++) {
        for (int i = 0; i < static_cast<int>(i_m->second.size()); i++) {
            kickstart_sets.push_back(i_m->second[i]);
            cerr << kickstart_sets.back() << endl;
        }
    }


    //    for (map<int, vector<set<int> > >::iterator i_m = kickstart_set_map.begin();
    //            i_m != kickstart_set_map.end(); i_m++) {
    //        for (int i = 0; i < i_m->second.size(); i++) {
    //            kickstart_sets.push_back(i_m->second[i]);
    //        }
    //    }
    //
    //    for (int i = 0; i < kickstart_sets.size(); i++) {
    //        kickstart_sets.push_back(scc[i]);
    //        cerr << kickstart_sets.back() << endl;
    //    }
    //    if (scc.size() == 1) return true;
    //    cerr << scc.size() << endl;
    /* enqueue variable sets of strongly connected component in CG */
    for (int i = 0; i < static_cast<int>(kickstart_sets.size()); i++) {
        cerr << "sum_computation_bound " << sum_computation_bound << endl;
        cerr << "sum_computation " << sum_computation << endl;
        sum_computation_bound = sum_computation
                + (kickstart_sum_computation_bound - sum_computation) / (kickstart_sets.size() - i);
        cerr << "sum_computation_bound " << sum_computation_bound << endl;
        cerr << "sum_computation " << sum_computation << endl;
        cerr << "checking " << kickstart_sets[i] << endl;
        if (enqueued.count(kickstart_sets[i])) continue;

        LatticeVarset lv;
        /* compute reachable count and
         * if detect unsolvability in the space then terminate */
        if (!make_lattice_varset(kickstart_sets[i], lv)) return false;
        if (lv.diff != -1) {
            //            update_tables(lv.varset, lv.ratio, lv.diff);
            //            if (!compute_diff_for(lv)) return false;

            cerr << "pushing kickstart subset " << lv.varset
                    << " " << lv.ratio
                    << "=" << lv.reach
                    << "/" << lv.total
                    << " " << lv.diff << endl;

            /* insert the R for this set */
            //            ratio_table.insert(make_pair(lv.varset, lv.ratio));
            /* enqueue the subset into the priority queue */
            enqueue(pq, lv);
        } else {
            cerr << "skip " << lv.varset << " " << lv.ratio
                    << "=" << lv.reach
                    << "/" << lv.total
                    << " " << lv.diff << endl;
        }

        vector<LatticeVarset> promising_subsets;
        subset_check(lv.varset, enqueued, promising_subsets);
        for (int j = 0; j < static_cast<int>(promising_subsets.size()); j++) {

            cerr << "pushing promising subset " << promising_subsets[j].varset << endl;
            assert(!enqueued.count(promising_subsets[j].varset));
            enqueue(pq, promising_subsets[j]);
        }


    }
    return true;
}

/**
 * subset searching
 * @return false if unsolvability detected; true otherwise
 */
bool MergeAndShrinkHeuristic::search_promising_clusters() {
    sum_total_enqueued = 0;
    sum_computation = 0;
    ratio_table.clear();
    diff_table.clear();
    enqueued.clear();
    /* the table stores the R for each subset */
    ratio_table.insert(make_pair(set<int>(), 1));

    enqueued.insert(set<int>());

    /* enqueue atomic variable sets */
    for (int i = 0; i < static_cast<int>(tasks::g_root_task->get_num_variables())
            && sum_computation <= sum_computation_bound; ++i) {

        /* construct the ATOMIC set */
        set<int> var_set;
        var_set.insert(i);

        LatticeVarset lv;
        /* compute reachable count and
         * if detect unsolvability in the space then terminate */
        if (!make_lattice_varset(var_set, lv)) return false;
        assert(lv.diff != -1 && lv.ratio != -1 && lv.reach != -1);
        /* insert the R for this set */
        //        update_tables(lv.varset, lv.ratio, lv.diff);
        //        ratio_table.insert(make_pair(lv.varset, lv.ratio));
        /* enqueue the subset into the priority queue */
        enqueue(pq, lv);
        /* insert weight for this atomic set */
        weight.insert(make_pair(lv.varset, -log(lv.ratio)));
    }

    kickstart_sum_computation_bound = sum_computation_bound * 9 / 10;
    arbitary_sum_computation_bound = sum_computation_bound / 10;

    if (!kickstart()) return false;

    /* reset computation */
    //    sum_computation = 0;
    sum_computation_bound = arbitary_sum_computation_bound
            + kickstart_sum_computation_bound
            - sum_computation;

    while (!pq.empty()) {
        LatticeVarset popped = pq.top();

        if (popped.diff > 0) {
            weight.insert(make_pair(popped.varset, -log(popped.ratio)));
            cerr << popped.varset << " " << popped.ratio
                    << "=" << popped.reach
                    << "/" << popped.total << endl;
        }
        pq.pop();

        for (int i = 0; i < static_cast<int>(tasks::g_root_task->get_num_variables())
                && sum_computation <= sum_computation_bound; ++i) {
            //            cerr << sum_computation << " <= "
            //                    << sum_computation_bound << endl;

            if (popped.varset.count(i)) continue;

            //            if (sum_total_enqueued +
            //                    popped.total * g_variable_domain[i]
            //                    > sum_total_enqueued_bound) continue;
            //            if (popped.reach * g_variable_domain[i]
            //                    > reach_bound) continue;

            set<int> push_set(popped.varset);
            push_set.insert(i);
            //            cerr << "pushing " << push_set << "?" << endl;
            if (enqueued.count(push_set)) continue;

            LatticeVarset push_attempt;
            if (!make_lattice_varset(push_set, push_attempt)) return false;
            if (push_attempt.diff != -1) {
                //                cerr << "pushing " << push_set << endl;

                //                update_tables(push_attempt.varset, push_attempt.ratio,
                //                        push_attempt.diff);
                //                ratio_table.insert(make_pair(push_set, push_attempt.ratio));
                //                if (!compute_diff_for(push_attempt)) return false;
                /* enqueue the subset into the priority queue */
                enqueue(pq, push_attempt);
            }
        }
    }
    return true;
}

/**
 * Compute the number of necessary states in the abstraction on a variable
 * set. The abstraction is built by merging the atomic abstractions in the
 * given order without shrinking.
 *
 * @param order: the merging order (variable ids)
 * @param reach: receives the number of necessary states, or -1 if a merge
 *               step would exceed reach_bound
 * @return false if an intermediate abstraction is unsolvable; true otherwise
 */
bool MergeAndShrinkHeuristic::reach_count(const vector<int>& order,
        int& reach) {
    if (order.empty()) {
        // Nothing to merge. Report a single state, matching total_count()
        // for an empty set, instead of indexing order[0] out of bounds.
        reach = 1;
        return true;
    }

    const int k = order.size();

    vector<Abstraction*> atomic_abstraction;
    Abstraction::build_atomic_abstractions(atomic_abstraction, labels, true);

    vector<Abstraction*> intermediate_abstraction;
    intermediate_abstraction.push_back(atomic_abstraction[order[0]]);

    // Frees everything allocated by this call. Entry 0 of
    // intermediate_abstraction aliases an atomic abstraction, so it is
    // skipped in the second loop.
    auto release_all = [&]() {
        for (size_t i = 0; i < atomic_abstraction.size(); ++i) {
            atomic_abstraction[i]->release_memory();
            delete atomic_abstraction[i];
        }
        for (size_t i = 1; i < intermediate_abstraction.size(); ++i) {
            intermediate_abstraction[i]->release_memory();
            delete intermediate_abstraction[i];
        }
    };

    vector<int> A;  // variables merged so far, for the diagnostic message
    A.push_back(order[0]);

    int computation_amount = 0;
    bool exceeded_bound = false;

    for (int i = 1; i < k; i++) {
        int a = 0;
        int b = 0;
        intermediate_abstraction.back()->compute_reachable_count(a);
        atomic_abstraction[order[i]]->compute_reachable_count(b);

        computation_amount += intermediate_abstraction.back()->size();
        computation_amount += atomic_abstraction[order[i]]->size();

        // Widen before multiplying so the product cannot overflow int.
        if (static_cast<long long>(a) * b > reach_bound) {
            cerr << "can't merge " << A << " (" << a << ") and [" << order[i]
                    << "] (" << b << ") without exceeding "
                    << "the limit " << reach_bound << endl;
            exceeded_bound = true;
            break;
        }

        intermediate_abstraction.back()->normalize();
        atomic_abstraction[order[i]]->normalize();

        intermediate_abstraction.push_back(new CompositeAbstraction(
                labels,
                intermediate_abstraction.back(),
                atomic_abstraction[order[i]],
                true));
        A.push_back(order[i]);
        if (!intermediate_abstraction.back()->is_solvable()) {
            // Do not leak the abstractions on the unsolvable path.
            release_all();
            return false;
        }
    }

    if (exceeded_bound) {
        // Keep the -1 sentinel; querying the partial abstraction here would
        // overwrite it with the count of a prefix of the order.
        reach = -1;
    } else {
        intermediate_abstraction.back()->compute_reachable_count(reach);
        computation_amount += intermediate_abstraction.back()->size();
    }

    release_all();

    sum_computation += computation_amount;

    return true;
}

/**
 * compute the total number of states in the abstraction on the variable set given
 * @param v: the variable set
 * @return the total number of states
 */
int MergeAndShrinkHeuristic::total_count(const vector<int>& v) {
    int ret = 1;
    for (int i = 0; i < static_cast<int>(v.size()); i++) ret *= tasks::g_root_task->get_variable_domain_size(v[i]);

    return ret;
}

/**
 * Record $R$ and $R_d$ for a variable set in the look-up tables. Existing
 * entries are never overwritten; in debug builds they are checked to agree
 * with the new values up to 1e-7.
 * @param varset: the variable set
 * @param ratio:  $R$ as in the paper
 * @param diff: $R_d$ as in the paper
 */
void MergeAndShrinkHeuristic::update_tables(const set<int> varset,
        const double ratio, const double diff) {
    if (ratio_table.count(varset)) {
        assert(abs(ratio - ratio_table[varset]) <= 1e-7);
    } else {
        ratio_table.insert(make_pair(varset, ratio));
    }

    if (diff_table.count(varset)) {
        assert(abs(diff - diff_table[varset]) <= 1e-7);
    } else {
        diff_table.insert(make_pair(varset, diff));
    }
}

// Runs the merge-and-shrink main loop: builds and shrinks the atomic
// abstractions, then repeatedly asks the merge strategy for the next pair,
// reduces labels, computes distances, shrinks, and replaces the pair by
// their composite until the merge strategy reports done().
// Returns the single remaining abstraction, or the first abstraction found
// to be unsolvable. Also prints statistics about shrinking, the merge order
// and merging for symmetries.
Abstraction *MergeAndShrinkHeuristic::build_abstraction() {
    // TODO: We're leaking memory here in various ways. Fix this.
    //       Don't forget that build_atomic_abstractions also
    //       allocates memory.

    // vector of all abstractions. entries with 0 have been merged.
    vector<Abstraction *> all_abstractions;
    all_abstractions.reserve(tasks::g_root_task->get_num_variables() * 2 - 1);
    Abstraction::build_atomic_abstractions(all_abstractions, labels);

    cout << "Shrinking atomic abstractions..." << endl;
    for (size_t i = 0; i < all_abstractions.size(); ++i) {
        all_abstractions[i]->compute_distances();
        // An unsolvable atomic abstraction is returned immediately.
        if (!all_abstractions[i]->is_solvable())
            return all_abstractions[i];
        shrink_strategy->shrink_atomic(*all_abstractions[i]);
    }

    // Track the largest abstraction size seen, starting with the atomic ones.
    int maximum_intermediate_size = 0;
    for (size_t i = 0; i < all_abstractions.size(); ++i) {
        int size = all_abstractions[i]->size();
        if (size > maximum_intermediate_size) {
            maximum_intermediate_size = size;
        }
    }
    int iteration_counter = 0;
    bool still_perfect = true;
    vector<pair<int, int>> merge_order;
    // Counters and flags for the "merging for symmetries" statistics
    // printed at the end of this function.
    int num_attempts_merging_for_symmetries = 0;
    int num_imperfect_shrinking_merging_for_symmetries = 0;
    int num_pruning_merging_for_symmetries = 0;
    int num_failed_merging_for_symmetries = 0;
    bool merging_for_symmetries = true;
    bool currently_shrink_perfect_for_symmetries = true;
    bool currently_prune_perfect_for_symmetries = true;

    cout << "Merging abstractions..." << endl;

    while (!merge_strategy->done()) {
        pair<int, int> next_systems = merge_strategy->get_next(all_abstractions);
        // When a symmetry-merging phase has just ended, record whether
        // shrinking or pruning interfered during that phase.
        if (merge_strategy->ended_merging_for_symmetries()) {
            merging_for_symmetries = false;
            if (!currently_shrink_perfect_for_symmetries) {
                ++num_imperfect_shrinking_merging_for_symmetries;
            }
            if (!currently_prune_perfect_for_symmetries) {
                ++num_pruning_merging_for_symmetries;
            }
            if (!currently_shrink_perfect_for_symmetries ||
                !currently_prune_perfect_for_symmetries) {
                ++num_failed_merging_for_symmetries;
            }
        }
        // A new symmetry-merging phase resets the per-phase flags.
        if (merge_strategy->started_merging_for_symmetries()) {
            ++num_attempts_merging_for_symmetries;
            merging_for_symmetries = true;
            currently_shrink_perfect_for_symmetries = true;
            currently_prune_perfect_for_symmetries = true;
        }
        merge_order.push_back(next_systems);
        int system_one = next_systems.first;
        Abstraction *abstraction = all_abstractions[system_one];
        assert(abstraction);
        int system_two = next_systems.second;
        assert(system_one != system_two);
        Abstraction *other_abstraction = all_abstractions[system_two];
        assert(other_abstraction);

        // Note: we do not reduce labels several times for the same abstraction
        bool reduced_labels = false;
        if (shrink_strategy->reduce_labels_before_shrinking()) {
            labels->reduce(make_pair(system_one, system_two), all_abstractions);
            reduced_labels = true;
            abstraction->normalize();
            other_abstraction->normalize();
            abstraction->statistics(use_expensive_statistics);
            other_abstraction->statistics(use_expensive_statistics);
        }

        // distances need to be computed before shrinking
        bool pruned_unreachable1 = abstraction->compute_distances();
        bool pruned_unreachable2 = other_abstraction->compute_distances();
        if (!abstraction->is_solvable())
            return abstraction;
        if (!other_abstraction->is_solvable())
            return other_abstraction;
        if (merging_for_symmetries && (pruned_unreachable1 || pruned_unreachable2)) {
            currently_prune_perfect_for_symmetries = false;
        }

        // Sizes before shrinking, used below to detect whether either
        // abstraction actually got smaller.
        int previous_size1 = abstraction->get_size();
        int previous_size2 = other_abstraction->get_size();
        shrink_strategy->shrink_before_merge(*abstraction, *other_abstraction);
        // TODO: Make shrink_before_merge return a pair of bools
        //       that tells us whether they have actually changed,
        //       and use that to decide whether to dump statistics?
        // (The old code would print statistics on abstraction iff it was
        // shrunk. This is not so easy any more since this method is not
        // in control, and the shrink strategy doesn't know whether we want
        // expensive statistics. As a temporary aid, we just print the
        // statistics always now, whether or not we shrunk.)
        abstraction->statistics(use_expensive_statistics);
        other_abstraction->statistics(use_expensive_statistics);
        bool shrunk = abstraction->get_size() < previous_size1 || other_abstraction->get_size() < previous_size2;
        if (merging_for_symmetries && currently_shrink_perfect_for_symmetries && shrunk) {
            currently_shrink_perfect_for_symmetries = false;
        }

        const vector<double> &miss_qualified_states_ratios =
            shrink_strategy->get_miss_qualified_states_ratios();
        int size = miss_qualified_states_ratios.size();
        // A non-zero ratio in either of the two most recent entries marks
        // the first iteration that is no longer perfect.
        if (size >= 2 && still_perfect &&
            (miss_qualified_states_ratios[size - 1]
             || miss_qualified_states_ratios[size - 2])) {
            // The test for size >= 2 is to ensure we actually record
            // this kind of statistics -- currently only with bisimulation
            // shrinking.
            cout << "not perfect anymore in iteration " << iteration_counter << endl;
            still_perfect = false;
        }

        if (!reduced_labels) {
            labels->reduce(make_pair(system_one, system_two), all_abstractions);
        }
        abstraction->normalize();
        other_abstraction->normalize();
        if (!reduced_labels) {
            // only print statistics if we just possibly reduced labels
            other_abstraction->statistics(use_expensive_statistics);
            abstraction->statistics(use_expensive_statistics);
        }

        Abstraction *new_abstraction = new CompositeAbstraction(labels,
                                                                abstraction,
                                                                other_abstraction);

        abstraction->release_memory();
        other_abstraction->release_memory();

        new_abstraction->statistics(use_expensive_statistics);

        // The merged components are replaced by null entries; the composite
        // is appended at the end of the vector.
        all_abstractions[system_one] = 0;
        all_abstractions[system_two] = 0;
        all_abstractions.push_back(new_abstraction);

        int abs_size = new_abstraction->size();
        if (abs_size > maximum_intermediate_size) {
            maximum_intermediate_size = abs_size;
        }
        ++iteration_counter;
    }

    assert(static_cast<int>(all_abstractions.size()) == tasks::g_root_task->get_num_variables() * 2 - 1);
    // Find the single non-null entry; more than one is a critical error.
    // NOTE(review): if no entry is non-null, final_abstraction stays null
    // and is dereferenced below.
    Abstraction *final_abstraction = 0;
    for (size_t i = 0; i < all_abstractions.size(); ++i) {
        if (all_abstractions[i]) {
            if (final_abstraction) {
                cerr << "Found more than one remaining abstraction!" << endl;
                exit_with(ExitCode::CRITICAL_ERROR);
            }
            final_abstraction = all_abstractions[i];
            assert(i == all_abstractions.size() - 1);
        }
    }

    final_abstraction->compute_distances();

    cout << "Maximum intermediate abstraction size: "
         << maximum_intermediate_size << endl;
    const vector<double> &miss_qualified_states_ratios =
        shrink_strategy->get_miss_qualified_states_ratios();
    cout << "Course of miss qualified states shrinking: "
         << miss_qualified_states_ratios << endl;
    double summed_values = 0;
    for (double value : miss_qualified_states_ratios) {
        summed_values += value;
    }
    size_t number_of_shrinks = miss_qualified_states_ratios.size();
    // The average stays 0 when no ratios were recorded.
    double average_imperfect_shrinking = 0;
    if (number_of_shrinks) {
        average_imperfect_shrinking = summed_values / static_cast<double>(number_of_shrinks);
    }
    cout << "Average imperfect shrinking: " << average_imperfect_shrinking << endl;
    cout << "Merge order: [";
    // The order counts as linear if every merge after the first involves
    // the abstraction created by the previous merge. next_index starts at
    // the number of variables, the index of the first composite.
    bool linear_order = true;
    int next_index = task_proxy.get_variables().size();
    for (size_t i = 0; i < merge_order.size(); ++i) {
        pair<int, int> merge = merge_order[i];
        cout << "(" << merge.first << ", " << merge.second << ")";
        if (i != merge_order.size() - 1) {
            cout << ", ";
        }
        if (linear_order && i != 0) {
            if (merge.first != next_index && merge.second != next_index) {
                linear_order = false;
            }
            ++next_index;
        }
    }
    cout << "]" << endl;
    if (linear_order) {
        cout << "Linear merge order" << endl;
    } else {
         cout << "Non-linear merge order" << endl;
    }

    cout << "Number of attempts to merge for symmetries: "
         << num_attempts_merging_for_symmetries << endl;
    cout << "Number of times non-perfect shrinking interfered merging for symmetries: "
         << num_imperfect_shrinking_merging_for_symmetries << endl;
    cout << "Number of times pruning interfered merging for symmetries: "
         << num_pruning_merging_for_symmetries << endl;
    cout << "Number of times merging for symmetries failed for any reason: "
         << num_failed_merging_for_symmetries << endl;

    // An unsolvable final abstraction is returned without printing its
    // statistics or releasing its memory.
    if (!final_abstraction->is_solvable())
        return final_abstraction;

    final_abstraction->statistics(use_expensive_statistics);
    final_abstraction->release_memory();

    return final_abstraction;
}

void MergeAndShrinkHeuristic::initialize() {
    Timer timer;
    cout << "Initializing merge-and-shrink heuristic..." << endl;
    starting_peak_memory = get_peak_memory_in_kb();
    dump_options();
    warn_on_unusual_options();

    task_properties::verify_no_axioms(task_proxy);

    cout << "Building abstraction..." << endl;
    final_abstraction = build_abstraction();
    if (!final_abstraction->is_solvable()) {
        cout << "Abstract problem is unsolvable!" << endl;
    }

    cout << "Final transition system size: " << final_abstraction->size() << endl;

    cout << "Done initializing merge-and-shrink heuristic [" << timer << "]"
         << endl;// << "initial h value: " << compute_heuristic(g_initial_state()) << endl;
    cout << "Estimated peak memory for abstraction: " << final_abstraction->get_peak_memory_estimate() << " bytes" << endl;
    report_peak_memory_delta(true);
}

/*
  Heuristic value of the given state: the cost stored for it in the
  final abstraction. A stored cost of -1 is translated to DEAD_END.
*/
int MergeAndShrinkHeuristic::compute_heuristic(const GlobalState &state) {
    const int abstract_cost = final_abstraction->get_cost(state);
    if (abstract_cost != -1) {
        return abstract_cost;
    }
    // -1 is the marker get_cost() uses for states that map to DEAD_END.
    return DEAD_END;
}

/*
  Option-parser callback for the merge-and-shrink heuristic.

  Documents the heuristic, declares all of its options (merge strategy,
  shrink strategy, label reduction settings, statistics flag, bounds and
  cost type plus the generic heuristic options) and parses them.

  Returns a null pointer on a dry run. Otherwise validates the option
  combination and returns a newly allocated MergeAndShrinkHeuristic;
  the "old" label reduction method combined with a merge strategy whose
  name is not "linear" is reported through parser.error().
*/
static Heuristic *_parse(OptionParser &parser) {
    parser.document_synopsis("Merge-and-shrink heuristic", "");
    parser.document_language_support("action costs", "supported");
    parser.document_language_support("conditional effects", "supported (but see note)");
    parser.document_language_support("axioms", "not supported");
    parser.document_property("admissible", "yes");
    parser.document_property("consistent", "yes");
    parser.document_property("safe", "yes");
    parser.document_property("preferred operators", "no");
    parser.document_note(
        "Note",
        "Conditional effects are supported directly. Note, however, that "
        "for tasks that are not factored (in the sense of the JACM 2014 "
        "merge-and-shrink paper), the atomic abstractions on which "
        "merge-and-shrink heuristics are based are nondeterministic, "
        "which can lead to poor heuristics even when no shrinking is "
        "performed.");

    parser.add_option<MergeStrategy *>(
        "merge_strategy",
        "merge strategy; choose between merge_linear and merge_dfp",
        "merge_linear");

    parser.add_option<ShrinkStrategy *>(
        "shrink_strategy",
        "shrink strategy; "
        "try one of the following:",
        "shrink_fh(max_states=50000, max_states_before_merge=50000, shrink_f=high, shrink_h=low)");
    vector<pair<string, string>> shrink_value_explanations;
    shrink_value_explanations.emplace_back(
        "shrink_fh(max_states=N)",
        "f-preserving abstractions from the "
        "Helmert/Haslum/Hoffmann ICAPS 2007 paper "
        "(called HHH in the IJCAI 2011 paper by Nissim, "
        "Hoffmann and Helmert). "
        "Here, N is a numerical parameter for which sensible values "
        "include 1000, 10000, 50000, 100000 and 200000. "
        "Combine this with the default linear merge strategy "
        "CG_GOAL_LEVEL to match the heuristic "
        "in the paper.");
    shrink_value_explanations.emplace_back(
        "shrink_bisimulation(max_states=infinity, threshold=1, greedy=true, initialize_by_h=false, group_by_h=false)",
        "Greedy bisimulation without size bound "
        "(called M&S-gop in the IJCAI 2011 paper by Nissim, "
        "Hoffmann and Helmert). "
        "Combine this with the linear merge strategy "
        "REVERSE_LEVEL to match "
        "the heuristic in the paper. ");
    shrink_value_explanations.emplace_back(
        "shrink_bisimulation(max_states=N, greedy=false, initialize_by_h=true, group_by_h=true)",
        "Exact bisimulation with a size limit "
        "(called DFP-bop in the IJCAI 2011 paper by Nissim, "
        "Hoffmann and Helmert), "
        "where N is a numerical parameter for which sensible values "
        "include 1000, 10000, 50000, 100000 and 200000. "
        "Combine this with the linear merge strategy "
        "REVERSE_LEVEL to match "
        "the heuristic in the paper.");
    parser.document_values("shrink_strategy", shrink_value_explanations);

    // The order of these names defines the enum values checked below.
    const vector<string> label_reduction_method = {
        "NONE",
        "OLD",
        "TWO_ABSTRACTIONS",
        "ALL_ABSTRACTIONS",
        "ALL_ABSTRACTIONS_WITH_FIXPOINT"
    };
    // Position of "OLD" in label_reduction_method.
    const int old_label_reduction_index = 1;
    // Each sentence of the help text ends in ". " so that the adjacent
    // literals do not run into each other once concatenated.
    parser.add_enum_option("label_reduction_method", label_reduction_method,
                           "label reduction method: "
                           "none: no label reduction will be performed. "
                           "old: emulate the label reduction as described in the "
                           "IJCAI 2011 paper by Nissim, Hoffmann and Helmert. "
                           "two_abstractions: compute the 'combinable relation' "
                           "for labels only for the two abstractions that will "
                           "be merged next and reduce labels. "
                           "all_abstractions: compute the 'combinable relation' "
                           "for labels once for every abstraction and reduce "
                           "labels. "
                           "all_abstractions_with_fixpoint: keep computing the "
                           "'combinable relation' for labels iteratively for all "
                           "abstractions until no more labels can be reduced.",
                           "ALL_ABSTRACTIONS_WITH_FIXPOINT");
    const vector<string> label_reduction_system_order = {
        "REGULAR",
        "REVERSE",
        "RANDOM"
    };
    parser.add_enum_option("label_reduction_system_order", label_reduction_system_order,
                           "order of transition systems for the label reduction methods "
                           "that iterate over the set of all abstractions. only useful "
                           "for the choices all_abstractions and all_abstractions_with_fixpoint "
                           "for the option label_reduction_method.", "RANDOM");
    parser.add_option<bool>("expensive_statistics",
                            "show statistics on \"unique unlabeled edges\" (WARNING: "
                            "these are *very* slow, i.e. too expensive to show by default "
                            "(in terms of time and memory). When this is used, the planner "
                            "prints a big warning on stderr with information on the performance impact. "
                            "Don't use when benchmarking!)",
                            "false");
    parser.add_option<int>("sum_computation_bound",
                           "TODO: description",
                           "1000000");
    parser.add_option<int>("miasm_max_states",
                           "should be the same as shrink_strategy->get_max_states() "
                           "unless infinity",
                           "50000");

    /* HACK: add previous option of heuristics back in.*/
    const vector<string> cost_types = {"NORMAL", "ONE", "PLUSONE"};
    parser.add_enum_option("cost_type",
                           cost_types,
                           "operator cost adjustment",
                           "NORMAL");

    Heuristic::add_options_to_parser(parser);
    Options opts = parser.parse();

    // On a dry run only the option declarations matter; build nothing.
    if (parser.dry_run()) {
        return nullptr;
    }

    if (opts.get_enum("label_reduction_method") == old_label_reduction_index
        && opts.get<MergeStrategy *>("merge_strategy")->name() != "linear") {
        parser.error("old label reduction is only correct when used with a "
                     "linear merge strategy!");
    }
    return new MergeAndShrinkHeuristic(opts);
}

// File-local plugin object that associates the key "nlr_merge_and_shrink"
// with the _parse callback above. NOTE(review): presumably the Plugin
// constructor registers the heuristic with the option parser during
// static initialization -- confirm against ../plugin.h.
static Plugin<Heuristic> _plugin("nlr_merge_and_shrink", _parse);
}
