#include <cassert>
#include <algorithm>
#include <utility>

#include "lifted_strips_task.h"
#include "../conj_query_eval/join_order_generator/join_order_generator.h"
#include "magic_enum.hpp"
#include "../heuristics/regr_add.h"

namespace HELP { 

//#include "../conj_query_eval/join_and_project.h"

//TODO: add to util for debugging: https://stackoverflow.com/questions/24441505/retrieving-the-type-of-auto-in-c11-without-executing-the-program

void LiftedStripsTask::print_stats(std::ostream &outs) {
    // Size summary of the task, one "label: value" line per metric.
    outs << "Initial state size: " << initial_state.transform_to_vec().size() << std::endl
         << "Action amount: " << actions.size() << std::endl
         << "Object amount: " << objects.size() << std::endl
         << "Maximal predicate arity: " << get_max_predicate_arity() << std::endl
         << "Maximal amount free variables in condition: " << get_max_var_amount() << std::endl
         << "Predicate amount: " << predicates.size() << std::endl;
}

ll LiftedStripsTask::get_max_predicate_arity() { //TODO: can combine below
    // Largest arity over all predicates; 0 if there are no predicates.
    ll max_arity = 0;
    for (auto &pred : predicates) {
        const ll arity = pred.get_arity();
        if (arity > max_arity) {
            max_arity = arity;
        }
    }
    return max_arity;
}

ll LiftedStripsTask::get_max_var_amount() {
    // Largest Action::get_arity() over all actions; 0 if there are no actions.
    ll max_vars = 0;
    for (std::size_t i = 0; i < actions.size(); ++i) {
        const ll arity = actions[i].get_arity();
        if (arity > max_vars) {
            max_vars = arity;
        }
    }
    return max_vars;
}

void LiftedStripsTask::dump_dl_repr(std::ostream &outs) {
    // Whole task as a datalog program: initial-state facts first, then one rule per action.
    dump_dl_repr(initial_state, outs);
    dump_dl_repr(actions, outs);
}

void LiftedStripsTask::dump_dl_repr(Actions &actions, std::ostream &outs) {
    // Each action is written as a single rule terminated by a newline.
    for (auto it = actions.begin(); it != actions.end(); ++it) {
        dump_dl_repr(*it, outs);
        outs << std::endl;
    }
}

void LiftedStripsTask::dump_dl_repr(Action &action, std::ostream &outs) {
    // An action is printed as the rule "<add atom> :- <preconditions> [<cost>]."
    // which requires exactly one add effect and no delete effects.
    assert(action.get_add().size() == 1);
    assert(action.get_del().empty());

    // get_var_name() resolves variable names through this pointer; it is left
    // pointing at `action` after this call returns.
    current_print_action = &action;

    dump_dl_repr(action.get_add(), outs);  // rule head
    outs << " :- ";
    dump_dl_repr(action.get_pre(), outs);  // rule body
    outs << " [" << action.get_cost() << "].";
}

void LiftedStripsTask::dump_dl_repr(const Atom &atom, std::ostream &outs) {
    // "pred(args)", prefixed with "!" when the atom is negated.
    const char *negation = atom.is_negated() ? "!" : "";
    outs << negation << get_pred_name(atom.get_predicate()) << "(";
    dump_dl_repr(atom.get_args(), outs);
    outs << ")";
}

void LiftedStripsTask::dump_dl_repr(const ParameterOrObject &arg, std::ostream &outs) {
    // Objects print by name. Variables print either by name (from the action being
    // printed) or, when var_as_num is set, as a NumericalVar ("?<index>").
    if (arg.is_object()) {
        outs << get_obj_name(arg.get_index());
        return;
    }
    if (!var_as_num) {
        outs << get_var_name(arg.get_index());
        return;
    }
    dump_dl_repr(NumericalVar{arg.get_index()}, outs);
}

const std::string &LiftedStripsTask::get_obj_name(ll id) {
    // Name of object `id`; the index is not bounds-checked here.
    auto &object = objects[id];
    return object.get_name();
}

const std::string &LiftedStripsTask::get_pred_name(ll id) {
    // Name of predicate `id`; the index is not bounds-checked here.
    auto &predicate = predicates[id];
    return predicate.get_name();
}

const std::string &LiftedStripsTask::get_var_name(ll id) { //TODO: this is very hacky, use printable wrapper instead
    // Variable names are per-action. This looks them up in the action most recently
    // passed to dump_dl_repr(Action&), which is the only place current_print_action is set.
    // Calling this before any action was printed would dereference a null/stale pointer,
    // so fail loudly in debug builds instead.
    assert(current_print_action != nullptr && "get_var_name() called before an action was printed");
    return current_print_action->get_var_name(id);
}

void LiftedStripsTask::dump_dl_repr(StripsState &state, std::ostream &outs) {
    // Every ground atom of the state becomes a zero-cost fact: "pred(objs) [0]."
    auto facts = state.transform_to_vec();
    for (auto &fact : facts) {
        dump_dl_repr(fact, outs);
        outs << " [0]." << std::endl;
    }
}

void LiftedStripsTask::dump_dl_repr(const GroundAtom &atom, std::ostream &outs) { //TODO: combine with non grounded atom
    outs << get_pred_name(atom.get_predicate()) << "(";
    dump_dl_repr(atom.get_args(), outs); //TODO: this will degenerate to vector<ll> which currently is unqiue for obj, but seems very weird
    outs << ")";
}

void LiftedStripsTask::dump_dl_repr(const ObjectRef &obj, std::ostream &outs) {
    // Objects are printed by their name.
    const std::string &name = get_obj_name(obj);
    outs << name;
}

void LiftedStripsTask::print_predicates(std::ostream &outs) {
    // One predicate name per line, in predicate-id order.
    for (std::size_t i = 0; i < predicates.size(); ++i) {
        outs << predicates[i].get_name() << std::endl; //TODO: make this a friend function?
    }
}

void LiftedStripsTask::dump_dl_repr(const PrintableJoinOrderEl &el, std::ostream &outs) {
    // A join-order edge, printed as "<from> -> <to>".
    const auto &source = el.from;
    const auto &target = el.to;
    dump_dl_repr(source, outs);
    outs << " -> ";
    dump_dl_repr(target, outs);
}

void LiftedStripsTask::dump_dl_repr(const PrintableAnnotatedJoinOrderEl &el, std::ostream &outs) {
    // Edge followed by its annotation: "<edge>[joined: <pars>; tracked: <pars>]".
    dump_dl_repr(el.join_order_el, outs);
    outs << "[" << "joined: ";
    dump_dl_repr(el.join_pars, outs);
    outs << "; tracked: ";
    dump_dl_repr(el.tracked_pars, outs);
    outs << "]";
}

// Prints a Parameter by delegating to its stream insertion operator
// (defined outside this file).
void LiftedStripsTask::dump_dl_repr(const Parameter &el, std::ostream &outs) {
    outs << el;
}

Parameters Action::get_eff_pars() { //TODO: should be named add / del
    // Parameters occurring in the add effects, copied in the iteration order of
    // all_pars()'s result (presumably a set, i.e. duplicates removed — verify).
    auto add_pars = all_pars(add);
    Parameters eff_pars(add_pars.begin(), add_pars.end());
    return eff_pars;
}

void LiftedStripsTask::dump_dl_repr(const QueryEval::JoinOrderGraph::ExtNodeRepr &el, std::ostream &outs) {
    // Layout: "[node id: <id>, <node-kind specific fields>, <generic node fields>]".
    outs << "[" << "node id: ";
    dump_dl_repr(el.id, outs);
    outs << ", "; //TODO make this a "print seperator"

    // el.ptr exposes one pointer field per node kind (init / merge / reorder); el.id
    // tells which one is meaningful. The chosen pointer is also kept as a base-class
    // pointer so the generic node part can be printed uniformly afterwards.
    QueryEval::JoinGraphNode *base_node = nullptr;
    if (el.id.is_init()) { //TODO: visitor or create classes
        base_node = el.ptr.init_node;
        auto &init = *el.ptr.init_node->initializer;

        // The initializer's position order is shown as numbered variables ("?<n>").
        std::vector<NumericalVar> pos_vars;
        for (ll pos : init.pos_order) {
            pos_vars.push_back(NumericalVar{pos});
        }

        //TODO: wrap in function
        outs << "node type: " << "PredicateInit, " //TODO: combine this with all other types
             << get_pred_name(init.predicate) << "(";
        dump_dl_repr(pos_vars, outs);
        outs << ")";
    } else if (el.id.is_merge()) {
        base_node = el.ptr.merge_node;
        dump_dl_repr(el.ptr.merge_node, outs);
    } else {
        assert(el.id.is_reorder());
        base_node = el.ptr.reorder_node;
        dump_dl_repr(el.ptr.reorder_node, outs);
    }

    outs << ", ";
    dump_dl_repr(base_node, outs);
    outs << "]";
}

void LiftedStripsTask::dump_dl_repr(const QueryEval::NodeId &el, std::ostream &outs) {
    // "{ type: <enum name>, id_num: <n> }"
    const auto type_name = magic_enum::enum_name(el.get_enum_val());
    outs << "{ type: " << type_name;
    outs << ", id_num: " << el.id_num << " }";
}

void LiftedStripsTask::dump_dl_repr(const QueryEval::MergeNode &el, std::ostream &outs) { //TODO: create str consts for all of this and make constants accesible to pyunit
    // Merge-node specific fields: both input nodes and the three position maps.
    // Ends without a trailing separator.
    //TODO: combine the type label with all other types; maybe something like bracketvec for "name: {...}"
    outs << "node type: " << "MergeNode" << ", " << "from_l: ";
    dump_dl_repr(el.left, outs);
    outs << ", " << "from_r: ";
    dump_dl_repr(el.right, outs);
    outs << ", " << "left_forward: {";
    dump_dl_repr(el.left_positions_forward, outs);
    outs << "}, " << "right_forward: {";
    dump_dl_repr(el.right_positions_forward, outs);
    outs << "}, " << "joined_positions: {";
    dump_dl_repr(el.joined_positions, outs);
    outs << "}";
}
void LiftedStripsTask::dump_dl_repr(const QueryEval::ReOrderNode &el, std::ostream &outs) { //TODO: create str consts for all of this and make constants accesible to pyunit
    // Reorder-node specific fields: the source node and the position re-ordering.
    // Like the MergeNode printer, this ends without a trailing separator. The
    // ExtNodeRepr printer appends ", " itself before the generic node fields, so a
    // trailing ", " here would produce "}, , edges: ...".
    outs << "node type: " << "ReorderNode"; //TODO: combine this with all other types
    outs << ", ";
    outs << "from: ";
    dump_dl_repr(el.from, outs);
    outs << ", ";
    outs << "re_order: {"; //TODO: maybe we should do something like bracketvec, for name: annotated
    dump_dl_repr(el.re_order, outs);
    outs << "}";
}

void LiftedStripsTask::dump_dl_repr(const NumericalVar &var, std::ostream &outs) {
    // Numbered variable, e.g. "?3".
    outs << "?";
    outs << var.var_num;
}

void LiftedStripsTask::dump_dl_repr(const QueryEval::PositionForward &position_map, std::ostream &outs) { //TODO: there is probably some util to pretty print structs like this
    // "(from: <a>, to: <b>)"
    outs << "(from: " << position_map.from;
    outs << ", to: " << position_map.to << ")";
}

void LiftedStripsTask::dump_dl_repr(const QueryEval::PositionMerge &position_map, std::ostream &outs) { //TODO: there is probably some util to pretty print structs like this
    // "(from_left: <l>, from_right: <r>, to: <t>)"
    outs << "(from_left: " << position_map.from_left;
    outs << ", from_right: " << position_map.from_right;
    outs << ", to: " << position_map.to << ")";
}

void LiftedStripsTask::dump_dl_repr(const QueryEval::JoinGraphNode &node, std::ostream &outs) {
    // Generic node part: its outgoing edges in braces.
    //TODO: maybe we should do something like bracketvec, for name: annotated
    const auto &edges = node.edges;
    outs << "edges: {";
    dump_dl_repr(edges, outs);
    outs << "}";
}

void LiftedStripsTask::dump_dl_repr(const VarMap &var_map, std::ostream &outs) {
    // One assignment (list of variable -> object mappings) wrapped in braces.
    const auto &mappings = var_map.var_mapping;
    outs << "{";
    dump_dl_repr(mappings, outs);
    outs << "}";
}

void LiftedStripsTask::dump_dl_repr(const VarMapping &var_map, std::ostream &outs) {
    // "<var> -> <obj>" using the stream operators of both fields.
    outs << var_map.var << " -> " << var_map.obj; //TODO: also standard for string?
}

//TODO: to proper pars
void LiftedStripsTask::query_eval_print(QueryEval::JoinOrderGraph &graph, QueryEval::TableCache &cache, Action *action, std::ostream &outs, bool show_tables) {
    // Default mode: print the result of the graph's last request as a list of
    // variable assignments, one VarMap per result row. Variables are named via
    // `action` when it is non-null, otherwise as "?<var id>".
    if (!show_tables) {
        ll request_id = graph.get_last_request();
        auto [var_map, table] = graph.get_result(request_id, cache);

        std::vector<VarMap> var_maps;
        for (auto &row : table) {
            VarMap row_assignment;
            for (auto &var_and_col : var_map) {
                // TODO combine ?i with NumericalVar, rm hacky
                row_assignment.var_mapping.push_back({
                        action ? action->get_var_name(var_and_col.first)
                               : ("?" + std::to_string(var_and_col.first)),
                        get_obj_name(row[var_and_col.second])});
            }
            var_maps.push_back(row_assignment);
        }

        dump_dl_repr(var_maps, outs);
        return;
    }

    // Table mode: dump the cached table of every node as "| obj | obj |" rows.
    //TODO: make this an option and integrate into pretty print pattern
    for (auto &node : graph.get_nodes()) {
        dump_dl_repr(node.id, outs);
        outs << ":" << std::endl;
        // arr_id is read through merge_node regardless of the node's actual kind.
        for (auto &row : cache.get_node_table(node.ptr.merge_node->arr_id)) { //TODO is access verbose?
            outs << "|";
            for (auto obj : row) {
                outs << " " << get_obj_name(obj) << " |";
            }
            outs << std::endl;
        }
    }
}

//TODO: bool template just for consistency?
void LiftedStripsTask::query_eval_print(std::ostream &outs, bool show_tables) { //TODO: combine with _print_join_orders
    // For every action: print it, build and evaluate a dedicated join-order graph
    // for its query, then print either the resulting assignments ("applicable
    // actions") or the raw table cache (show_tables).
    //
    // One DBInfo instance is used for join-order creation, graph construction,
    // cache creation and evaluation, so every step sees the same database info
    // the graph was built with (same approach as combined_query_eval_print).
    DBInfo info = get_info();

    for (auto &action : actions) {
        dump_dl_repr(action, outs);
        outs << std::endl;
        if (!show_tables) {
            outs << "applicable actions: "; //TODO: appropriate name by template
        } else {
            outs << "table cache: "; //TODO: appropriate name by template
        }

        auto [join_order, query, result_pars] = create_annoted_join_order(action, info);
        QueryEval::JoinOrderGraph graph(query, join_order, info);
        QueryEval::TableCache cache = graph.create_table_cache_simple(info);
        QueryEval::SimplePredInitializerWrapManager start_node_manager = graph.create_simple_start_node_manager();
        graph.evaluate(initial_state, cache, start_node_manager, info);

        query_eval_print(graph, cache, &action, outs, show_tables);

        outs << std::endl << std::endl;
    }
}

void LiftedStripsTask::combined_query_eval_print(std::ostream &outs) { // TODO combine with query_eval_print //TODO: is now a misnomer
    DBInfo info = get_info();
    QueryEval::JoinOrderGraph jog(info);
    std::map<ll, std::pair<ll, QueryEval::NodeId>> action_id_to_res;
    QueryEval::TableCache cache = jog.create_table_cache_simple(info);
    QueryEval::RegressiveStartNodeMangerStandard start_node_manager = jog.create_recursive_start_nodes_manager_standard(info);

    ll node_am_last_extension = 0;
    for (ll id = 0; id < actions.size(); id++) {
        auto &action = actions[id];

        auto [query, result_pars] = create_query(action, info); //TODO: remove result par return
        jog.add_new_query(query, info);
        jog.adjust_for_new_nodes(cache, node_am_last_extension);
        jog.adjust_for_new_nodes(start_node_manager, node_am_last_extension);
        node_am_last_extension = jog.get_node_arr_am();
        action_id_to_res.emplace(id, std::make_pair(jog.get_last_request(), jog.get_last_request_node()));
    }

    for (auto &m : action_id_to_res) {
        auto &action = actions[m.first];
        auto &[req_id, req_node] = m.second;

        jog.mark_for_exploration(req_node, start_node_manager);
        jog.evaluate(initial_state, cache, start_node_manager, info);

        auto [var_map, table] = jog.get_result(req_id, cache);

        //TODO (important): combine with above - seperate dump_dl_repr
        std::vector<VarMap> var_maps;
        for (auto &row: table) {
            VarMap _var_map;
            for (auto &m: var_map) {
                _var_map.var_mapping.push_back({action.get_var_name(m.first), get_obj_name(row[m.second])});
            }
            var_maps.push_back(_var_map);
        }

        dump_dl_repr(action, outs);
        outs << std::endl << "applicable actions: ";
        dump_dl_repr(var_maps, outs);
        outs << std::endl << std::endl;
    }
}

NormalizeReturn LiftedStripsTask::normalize() { //TODO: maybe we want to add some task properties like normalized ...
    // Normalization passes, applied in this order: one_goal_only() introduces a
    // nullary goal predicate, which the following pass then turns unary.
    one_goal_only();

    // NOTE(review): remove_nullary_predicates_and_empty_preconditions() returns the id
    // of the "__true" object it adds (not the "__just_true" predicate id); confirm that
    // this is what consumers of NormalizeReturn expect.
    const ll true_obj_id = remove_nullary_predicates_and_empty_preconditions(); //TODO: I don't think this is needed anymore
    remove_constants_from_atoms();
    remove_equal_parameter(); //TODO: maybe handling this via projects would be the nicer way :)

    //TODO: remove atoms always true
    //TODO: remove unused pars
    //TODO: remove useless static predicates (like type@object)
    //TODO: make sure every parameter occurs in pre

    return {true_obj_id};
}

// Normalization pass that gets rid of nullary predicates and empty preconditions:
//  1. Adds a fresh predicate "__just_true" (declared with arity 0) and makes it the
//     precondition of every action whose precondition is empty.
//  2. Re-declares every nullary predicate (including "__just_true") as unary.
//  3. Lets each action rewrite its atoms via Action::make_atoms_unary(just_t_id)
//     (the exact rewrite is defined there, not visible here).
//  4. Adds a fresh object "__true", rewrites nullary atoms of the initial state and
//     the goal using it (presumably giving them "__true" as their single argument —
//     see remove_nullary), and adds the fact __just_true(__true) to the initial state.
// Returns the object id of "__true" (NOT the "__just_true" predicate id).
ll LiftedStripsTask::remove_nullary_predicates_and_empty_preconditions() {
    ll just_t_id = predicates.size();
    predicates.push_back({"__just_true", 0});
    // Presumably grows the initial state's per-predicate storage to cover the new predicate.
    initial_state.change_size(predicates.size());
    // Non-negated, argument-less atom over "__just_true".
    std::vector<Atom> always_true{{false, just_t_id, {}}};

    for (auto &action : actions) {
        if (action.get_pre().empty()) {
            action.set_pre(always_true);
        }
    }

    for (auto &pred : predicates) {
        if (pred.is_nullary()) {
            pred.set_unary();
        }
    }

    for (auto &action : actions) {
        action.make_atoms_unary(just_t_id);
    }

    ll true_obj_id = objects.size();
    objects.push_back({"__true"});

    initial_state.remove_nullary(true_obj_id);
    remove_nullary(goal, true_obj_id);

    GroundAtom just_true{just_t_id, {true_obj_id}};
    initial_state.insert(just_true);

    return true_obj_id;
}


std::set<ll> LiftedStripsTask::compute_static_predicates() {
    std::set<ll> potentially_static;

    for (ll i = 0; i < predicates.size(); i++) {
        potentially_static.insert(i);
    }

    for (auto &act : actions) {
        for (auto &atom : act.get_add()) {
            potentially_static.erase(atom.get_predicate());
        }
        for (auto &atom : act.get_del()) {
            potentially_static.erase(atom.get_predicate());
        }
    }

    return potentially_static;
}

//TODO: its weird in how many places ll should be unsigned, change that
std::map<ll, ll> LiftedStripsTask::compute_duplicated_objects() { //TODO: could be vector
    std::map<ll, ll> result;

    for (ll i = 0; i < predicates.size(); i++) {
        ll _max = 0;

        for (ll pos = 0; pos < predicates[i].get_arity(); pos++) {
            std::vector<ll> obj_ocur(objects.size(), 0);
            for (auto &atom : initial_state.get_atoms(i)) {
                obj_ocur[atom.get_args()[pos]]++;
            }

            _max = std::max(_max, *std::max_element(std::begin(obj_ocur), std::end(obj_ocur)));
        }

        result.emplace(i, _max);
    }

    return result;
}

std::map<ll, ll> LiftedStripsTask::compute_init_table_size() { //TODO: could be vector
    std::map<ll, ll> result;

    for (ll i = 0; i < predicates.size(); i++) {
        result.emplace(i, initial_state.get_atoms(i).size());
    }

    return result;
}

void LiftedStripsTask::print_initial_add_value(std::ostream &outs) {
    // Value RegrAdd (constructed for this task) computes for the initial state.
    // The label is written before the computation starts, as before.
    outs << "Initial h-add value is: ";
    const auto h_add = RegrAdd(*this).compute(initial_state);
    outs << h_add << std::endl << std::endl;
}

void LiftedStripsTask::dump_dl_repr(const QueryEval::InitCollection &init_col, std::ostream &outs) {
    // Turn every initializer row into an Atom so the generic atom printer can be
    // reused; each row value becomes a ParameterOrObject constructed as (true, value).
    std::vector<Atom> atoms;
    for (auto &[pred, rows] : init_col.predicate_init) {
        for (auto &row : rows) {
            std::vector<ParameterOrObject> row_args;
            for (auto value : row) {
                row_args.emplace_back(true, value);
            }
            atoms.emplace_back(pred, row_args);
        }
    }

    // Variables are printed numerically for this dump only; the flag is reset to false afterwards.
    var_as_num = true; //TODO: hacky rm
    dump_dl_repr(atoms, outs);
    var_as_num = false;
}

void LiftedStripsTask::remove_constants_from_atoms() {
    std::unordered_map<ll, ll> constant_to_const_pred;
    ll p_id_count = predicates.size();
    for (auto &act : actions) {
        act.remove_constants_from_atoms(constant_to_const_pred, p_id_count);
    }

    std::vector<ll> const_in_p_order(p_id_count-predicates.size());
    for (auto &[const_id, pred_id] : constant_to_const_pred) {
        const_in_p_order[pred_id-predicates.size()] = const_id;
    }

    initial_state.change_size(p_id_count);

    for (ll const_id : const_in_p_order) {
        GroundAtom gr_atom(predicates.size(), {{const_id}});
        initial_state.insert(gr_atom);
        predicates.emplace_back("@obj_"+objects[const_id].get_name(), 1);
    }
}

void LiftedStripsTask::remove_equal_parameter() {
    predicates.emplace_back("==", 2);
    for (auto &act : actions) {
        act.remove_equalities(predicates.size()-1);
    }

    initial_state.change_size(predicates.size());
    for (ll i = 0; i < objects.size(); i++) {
        GroundAtom gr_atom(predicates.size()-1, {{i}, {i}});
        initial_state.insert(gr_atom);
    }
}

void LiftedStripsTask::one_goal_only() {
    actions.emplace_back(std::string("achieve_goal"),
            std::vector<std::string>{}, //TODO: par
            goal,
            std::vector<Atom>{Atom(predicates.size(), std::vector<ParameterOrObject>{})},
            std::vector<Atom>{},
            0);
    goal.clear(); //TODO reset mem
    goal.emplace_back(predicates.size(), std::vector<ParameterOrObject>{});
    predicates.emplace_back("@artificial_goal", 0);
}

void LiftedStripsTask::simple_datalog_transformation() {
    // Splits every action into one action per add effect, so each resulting action
    // has exactly one add atom (the shape dump_dl_repr(Action&) prints as a rule).
    // Variable names, preconditions, delete list and cost are copied unchanged.
    // NOTE(review): keeping the delete list conflicts with the `get_del().size() == 0`
    // assertion in dump_dl_repr(Action&) for tasks that have delete effects — confirm.
    std::vector<Action> new_actions;

    for (auto &action : actions) {
        for (auto &add : action.get_add()) {
            new_actions.emplace_back(
                action.get_name() + std::string("TODO:extension"),
                action.get_var_names(), //TODO: par
                action.get_pre(),
                std::vector<Atom>{add},
                action.get_del(), //TODO: this is hacky, used to track static atoms that would not be static if only deleted
                action.get_cost()
                );
        }
    }

    // new_actions is not used afterwards: move instead of copying every action again.
    actions = std::move(new_actions);
}

}