#include "merge_tree.h"

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>
#include <vector>

#include "util/setop.h"

#include "../globals.h"
#include "../tasks/root_task.h"
#include "../task_utils/causal_graph.h"

#include "../utils/logging.h"

using namespace causal_graph;
using namespace std;
using namespace set_operation;
using namespace utils;

namespace mas_new_lr {
/*
  Build a merge tree for the given clusters of variables.

  Step 1 fixes a linear order over the variables (remaining_variables):
  - CLUSTER_EXTERNAL_LINEAR_LARGE_CG_GOAL_LEVEL: variables are grouped by the
    size of their cluster; groups of larger clusters come first and each
    group is taken in reverse order.
  - CLUSTER_EXTERNAL_LINEAR_SMALL_CG_GOAL_LEVEL: groups of non-atomic
    clusters come first, smaller before larger, each in collection order;
    the variables of atomic (size-1) clusters follow in reverse order.
  - CLUSTER_EXTERNAL_LINEAR_CG_GOAL_REVERSE_LEVEL: variables 0..n-1 of the
    root task in index order.
  Any other strategy is reported on cerr and the program exits with failure.

  Step 2 repeatedly picks a whole cluster:
  - first choice: the cluster of the first remaining variable that is a
    causal-graph predecessor of an already merged variable;
  - otherwise: the cluster of the first remaining goal variable;
  - otherwise: the cluster of the first remaining variable, so the loop
    always makes progress.
  The chosen cluster becomes a subtree via merge_tree_cluster_internal().
  That subtree is merged as the right child, with the tree built so far as
  the left child (linear merge).
*/
tree<node_t> MergeTree::merge_tree_clusters(
        const set<set<int> >& clusters,
        const ClusterExternalMergeStrategyType external,
        const ClusterInternalMergeStrategyType internal) {
    /* variable -> the cluster containing it (first cluster wins) */
    map<int, set<int> > v2c;
    for (const set<int>& cluster : clusters) {
        for (int var : cluster) {
            v2c.insert(make_pair(var, cluster));
        }
    }

    set<int> goal_var;
    for (int i = 0; i < tasks::g_root_task->get_num_goals(); ++i) {
        goal_var.insert(tasks::g_root_task->get_goal_fact(i).var);
    }

    vector<int> remaining_variables;
    if (external == CLUSTER_EXTERNAL_LINEAR_LARGE_CG_GOAL_LEVEL
            || external == CLUSTER_EXTERNAL_LINEAR_SMALL_CG_GOAL_LEVEL) {
        /* cluster size -> variables of all clusters of that size;
           std::map keeps the sizes in ascending order */
        map<int, vector<int> > size2cluster;
        for (const set<int>& cluster : clusters) {
            vector<int>& group = size2cluster[static_cast<int>(cluster.size())];
            group.insert(group.end(), cluster.begin(), cluster.end());
        }

        if (external == CLUSTER_EXTERNAL_LINEAR_LARGE_CG_GOAL_LEVEL) {
            /* variables of larger clusters are closer to the front;
               each size group is taken in reverse order */
            for (auto it = size2cluster.rbegin(); it != size2cluster.rend(); ++it) {
                remaining_variables.insert(remaining_variables.end(),
                        it->second.rbegin(), it->second.rend());
            }
        } else {
            /* variables of smaller non-atomic clusters are closer to the
               front; atomic clusters follow, in reverse order */
            for (const auto& entry : size2cluster) {
                if (entry.first != 1) {
                    remaining_variables.insert(remaining_variables.end(),
                            entry.second.begin(), entry.second.end());
                }
            }
            const auto atomic = size2cluster.find(1);
            if (atomic != size2cluster.end()) {
                remaining_variables.insert(remaining_variables.end(),
                        atomic->second.rbegin(), atomic->second.rend());
            }
        }

        cerr << remaining_variables << endl;
    } else if (external == CLUSTER_EXTERNAL_LINEAR_CG_GOAL_REVERSE_LEVEL) {
        const int num_variables =
            static_cast<int>(tasks::g_root_task->get_num_variables());
        for (int var = 0; var < num_variables; ++var) {
            remaining_variables.push_back(var);
        }
    } else {
        cerr << "unknown ClusterExternalMergeStrategyType" << endl;
        exit(EXIT_FAILURE);
    }

    /* the CG predecessors of the merged variables */
    set<int> predecessors;
    tree<node_t> merge_tree;
    while (!remaining_variables.empty()) {
        set<int> chosen_one;
        find_pred_cluster(v2c, remaining_variables, predecessors, chosen_one);
        if (chosen_one.empty()) {
            find_goal_cluster(v2c, goal_var, remaining_variables,
                    predecessors, chosen_one);
        }
        if (chosen_one.empty()) {
            /* Neither a predecessor nor a goal is left (e.g. variables not
               connected to the goal). Without this fallback the loop would
               never terminate, or merge an empty tree. */
            cerr << "no predecessor or goal cluster left, "
                 << "choosing the next remaining cluster" << endl;
            chosen_one = v2c[remaining_variables.front()];
            if (chosen_one.empty()) {
                cerr << "variable " << remaining_variables.front()
                     << " is not contained in any cluster" << endl;
                exit(EXIT_FAILURE);
            }
            update_cluster_pred(remaining_variables, chosen_one, predecessors);
        }

        tree<node_t> cluster_tree =
            merge_tree_cluster_internal(chosen_one, internal);
        if (merge_tree.empty()) {
            merge_tree = cluster_tree;
        } else {
            merge_tree = merge_subtrees(merge_tree, cluster_tree);
        }
    }

    return merge_tree;
}

/*
  Scan the remaining variables rv in order and select the cluster of the
  first variable contained in goal_var. That cluster is copied into
  chosen_cluster. update_cluster_pred() is then called, which removes the
  cluster's variables from rv and adds their CG predecessors to
  predecessors. Every inspected cluster is logged to cerr.
  chosen_cluster is left unchanged when no remaining variable is a goal.
*/
void MergeTree::find_goal_cluster(map<int, set<int> >& v2c,
        const set<int> goal_var, vector<int>& rv, set<int>& predecessors,
        set<int>& chosen_cluster) {
    cerr << "checking clusters for GOAL:\n";
    /* rv is only modified right before returning, so its size is stable */
    const size_t num_remaining = rv.size();
    for (size_t pos = 0; pos != num_remaining; ++pos) {
        const int var = rv[pos];
        assert(v2c.count(var));
        set<int>& cluster = v2c[var];
        cerr << cluster << " ";

        const bool is_goal = goal_var.count(var) != 0;
        cerr << (is_goal ? "is a goal\n" : "is not a goal\n");
        if (!is_goal) {
            continue;
        }

        chosen_cluster = cluster;
        update_cluster_pred(rv, chosen_cluster, predecessors);
        return;
    }
}

/*
  Scan the remaining variables rv in order and select the cluster of the
  first variable contained in predecessors (the CG predecessors of the
  variables merged so far). That cluster is copied into chosen_cluster.
  update_cluster_pred() then removes the cluster's variables from rv and
  adds their CG predecessors. Every inspected cluster is logged to cerr.
  chosen_cluster is left unchanged when no remaining variable qualifies.
*/
void MergeTree::find_pred_cluster(map<int, set<int> >& v2c,
        vector<int>& rv, set<int>& predecessors, set<int>& chosen_cluster) {

    cerr << "checking clusters for CG:\n";
    for (size_t pos = 0; pos < rv.size(); ++pos) {
        const int var = rv[pos];
        assert(v2c.count(var));
        set<int>& cluster = v2c[var];
        cerr << cluster << " ";
        if (predecessors.count(var) != 0) {
            cerr << "is a predecessor\n";
            chosen_cluster = cluster;
            update_cluster_pred(rv, chosen_cluster, predecessors);
            return;
        }
        cerr << "is not a predecessor\n";
    }
}

/*
  Add the causal-graph predecessors of every variable of chosen_cluster to
  predecessors. Then remove the cluster's variables from the remaining
  variables rv, keeping the relative order of the variables that stay.
  chosen_cluster itself is not modified.

  NOTE: following the original implementation (old repository ms-miasm,
  archived in sievers-et-al-aaai2015.tar.bz2), the normal causal graph is
  used here, whereas merge_and_shrink_heuristic.cc uses the legacy one.
*/
void MergeTree::update_cluster_pred(vector<int>& rv,
        set<int>& chosen_cluster, set<int>& predecessors) {
    const CausalGraph &cg = get_causal_graph(tasks::g_root_task.get());
    for (int var : chosen_cluster) {
        /* bind by reference: the predecessor list need not be copied */
        const vector<int>& pred_v = cg.get_predecessors(var);
        predecessors.insert(pred_v.begin(), pred_v.end());
    }

    /* erase-remove keeps the order of the remaining variables */
    rv.erase(std::remove_if(rv.begin(), rv.end(),
                 [&chosen_cluster](int var) {
                     return chosen_cluster.count(var) != 0;
                 }),
             rv.end());
}

/*
  Return a new tree whose root node is the union (get_union) of the root
  nodes of left and right. The subtrees rooted at left.begin() and
  right.begin() are appended to it as first and second child, via
  tree::append_child.
  Precondition: both trees are non-empty.
*/
tree<node_t> MergeTree::merge_subtrees(tree<node_t>& left,
        tree<node_t>& right) {
    /* dereferencing begin() of an empty tree is undefined behaviour */
    assert(!left.empty() && !right.empty());

    node_t merged_node;
    get_union(*(left.begin()), *(right.begin()), merged_node);

    tree<node_t> merged_tree;
    merged_tree.insert(merged_tree.begin(), merged_node);
    merged_tree.append_child(merged_tree.begin(), left.begin());
    merged_tree.append_child(merged_tree.begin(), right.begin());

    return merged_tree;
}

/*
  Build a linear merge tree over the variables of one cluster. The first
  variable of the chosen order becomes a single-variable leaf. Every
  following variable is merged in as the right child, with the tree built so
  far as the left child.

  CLUSTER_INTERNAL_LINEAR_REVERSE_LEVEL uses ascending variable order,
  CLUSTER_INTERNAL_LINEAR_LEVEL descending variable order. Any other strategy
  is reported on cerr and the program exits with failure. An empty cluster
  yields an empty tree.
*/
tree<node_t> MergeTree::merge_tree_cluster_internal(const set<int>& cluster,
        const ClusterInternalMergeStrategyType internal) {
    vector<int> cluster_order;
    if (internal == CLUSTER_INTERNAL_LINEAR_REVERSE_LEVEL) {
        cluster_order.assign(cluster.begin(), cluster.end());
    } else if (internal == CLUSTER_INTERNAL_LINEAR_LEVEL) {
        cluster_order.assign(cluster.rbegin(), cluster.rend());
    } else {
        /* an unknown strategy would otherwise yield an empty tree, which
           merge_subtrees() cannot handle */
        cerr << "unknown ClusterInternalMergeStrategyType" << endl;
        exit(EXIT_FAILURE);
    }

    tree<node_t> internal_tree;
    for (int var : cluster_order) {
        set<int> s;
        s.insert(var);
        tree<node_t> leaf;
        leaf.insert(leaf.begin(), s);

        if (internal_tree.empty()) {
            internal_tree = leaf;
        } else {
            internal_tree = merge_subtrees(internal_tree, leaf);
        }
    }

    return internal_tree;
}

/*
  Read a merge tree from the file "merge_tree_umc.temp" (opened relative to
  the current working directory). The parsing is done by
  readin_branching_recursive(). If the file cannot be opened, an error is
  reported on cerr and the program exits with failure. Without this check,
  the parser would run on an empty stream and index an empty vector.
*/
tree<node_t> MergeTree::readin_merge_tree_recursive() {
    const char *const filename = "merge_tree_umc.temp";
    ifstream fin_mt_rec(filename);
    if (!fin_mt_rec) {
        cerr << "could not open " << filename << endl;
        exit(EXIT_FAILURE);
    }
    return readin_branching_recursive(fin_mt_rec);
}

/*
  Recursively read one inner node of a merge tree from fin_mt_rec.

  Each node is described by one line of integers "c v_0 v_1 ... v_k":
  - The left child is read recursively from the following lines if c > 1;
    otherwise it is the leaf {v_0}.
  - The right child is read recursively (after the complete left subtree) if
    c < k; otherwise it is the leaf {v_k}.
  The result merges both children via merge_subtrees(). Presumably c is the
  number of variables in the left part of the split -- TODO confirm with the
  producer of the file.

  Each line, c and the variable list are echoed to cerr. A line whose
  leading integer cannot be read (including EOF), or that lists no
  variables, is reported and the program exits with failure. Previously such
  lines indexed an empty vector.
*/
tree<node_t> MergeTree::readin_branching_recursive(ifstream& fin_mt_rec) {
    string line;
    getline(fin_mt_rec, line);
    cerr << "line: " << line << endl;

    istringstream iss(line);
    int c = -1;
    const bool has_split = static_cast<bool>(iss >> c);
    cerr << "c: " << c << endl;

    vector<int> v;
    int var;
    while (iss >> var) {
        v.push_back(var);
    }
    cerr << v << endl;

    /* both leaf cases below access v.front()/v.back() */
    if (!has_split || v.empty()) {
        cerr << "malformed merge tree line: \"" << line << "\"" << endl;
        exit(EXIT_FAILURE);
    }

    tree<node_t> left_subtree;
    if (c > 1) {
        left_subtree = readin_branching_recursive(fin_mt_rec);
    } else {
        set<int> s;
        s.insert(v.front());
        left_subtree.insert(left_subtree.begin(), s);
    }

    tree<node_t> right_subtree;
    if (c < static_cast<int>(v.size()) - 1) {
        right_subtree = readin_branching_recursive(fin_mt_rec);
    } else {
        set<int> s;
        s.insert(v.back());
        right_subtree.insert(right_subtree.begin(), s);
    }

    return merge_subtrees(left_subtree, right_subtree);
}
//void MergeTree::find_pred_cluster(vector<set<int> >& rm_cluster,
//        ClusterCausalGraph& ccg, set<int>& chosen_one,
//        set<set<int> >& pred_cluster) {
//
//    cerr << "checking clusters for CG:\n";
//    for (int i = 0; i < rm_cluster.size(); i++) {
//        cerr << rm_cluster[i] << " ";
//        if (!pred_cluster.count(rm_cluster[i])) {
//            cerr << "is not a predecessor\n";
//            continue;
//        }
//        cerr << "is a predecessor\n";
//        /* copy the cluster */
//        chosen_one = rm_cluster[i];
//        /* update predecessor clusters */
//        update_cluster_pred(ccg, rm_cluster, i, pred_cluster);
//        return;
//    }
//}

//void MergeTree::find_goal_cluster(vector<set<int> >& rm_cluster,
//        ClusterCausalGraph& ccg, set<int>& chosen_one,
//        set<set<int> >& pred_cluster) {
//    cerr << "checking clusters for GOAL:\n";
//    for (int i = 0; i < rm_cluster.size(); i++) {
//        cerr << rm_cluster[i] << " ";
//        if (!ccg.goal_cluster.count(rm_cluster[i])) {
//            cerr << "is not a goal\n";
//            continue;
//        }
//        cerr << "is a goal\n";
//        /* copy the cluster */
//        chosen_one = rm_cluster[i];
//        /* update predecessor clusters */
//        update_cluster_pred(ccg, rm_cluster, i, pred_cluster);
//
//        return;
//    }
//}

//void MergeTree::update_cluster_pred(ClusterCausalGraph& ccg,
//        vector<set<int> >& rm_cluster, const int chosen_i,
//        set<set<int> >& pred_cluster) {
//    /* update predecessor clusters */
//    for (set<set<int> >::iterator i_set
//            = ccg.cluster_pred[rm_cluster[chosen_i]].begin();
//            i_set != ccg.cluster_pred[rm_cluster[chosen_i]].end();
//            i_set++) {
//        pred_cluster.insert(*i_set);
//    }
//    rm_cluster.erase(rm_cluster.begin() + chosen_i);
//}
}
