#include "regr_add.h"

namespace HELP {

// debugging options
// #define RADD_EXPLORATION_DEBUG_PRINT // print node currently explored, represented query
// #define RADD_EXPLORATION_DEBUG_PRINT_ORDER_PRINT // additionally print join order
// #define RADD_EXPLORATION_DEBUG_PRINT_CACHE_PRINT // additionally print cache
// #define RADD_EXPLORATION_DEBUG_PRINT_SUCCESSOR_AMOUNT // additionally print amount successor generated
// #define RADD_EXPLORATION_PRINT_GOAL_COST_AFTER_EXPLORE // print goal cost after each exploration
// #define RADD_EXPLORATION_STOP_AFTER 40 // 500 // stop iteration after x explored nodes
// #define RADD_EXPLORATION_STOP_AFTER_CALL 284 // stop iteration after x times heuristic computation called
// #define RADD_EXPLORATION_PRINT_MICRO_MACRO_GRAPH // print graph consisting of micro and macro nodes --> can be visualized using micro-macro-viz.py
// #define RADD_EXPLORATION_FULFILLED_DEBUG_PRINT // print if node was fulfilled
// #define RADD_EXPLORATION_OR_NODE_EXTENSION_DEBUG_PRINT // print the extensions of an or node
// #define RADD_EXPLORATION_OR_NODE_EXTENSION_VAR_MAP_DEBUG_PRINT // print more details about the extension: query before, achiever and var map
// #define RADD_EXPLORATION_EXTENSION_SPLIT_DEBUG_PRINT // print splits per extension
// #define RADD_EXPLORATION_DL_TASK_DEBUG_PRINT // prints DL task on init
// #define RADD_EXPLORATION_CHOSEN_ATOM_SUBST_DEBUG_PRINT // prints which atoms per query are substituted
// #define RADD_EXPLORATION_PRINT_CALLED_STATE_DEBUG_PRINT // print the received state for h-val computation
// #define RADD_EXPLORATION_PRINT_DETECTED_STATIC_PREDS_DEBUG_PRINT // print the static predicates detected on init
// #define RADD_EXPLORATION_SUPSET_CHECK_DEBUG_PRINT // print supset checks
// #define RADD_EXPLORATION_DEBUG_PRINT_EXTRACTED_FAIL_NODES // print extracted fail nodes
// #define RADD_EXPLORATION_PRINT_RESULT // print h val
// #define RADD_EXPLORATION_BACKWARDS_PASS_PRINT

// additional sanity checks
// #define RADD_EXPLORATION_FULFILLED_SANITY_CHECKS // checks that isolated static queries are not re-evaluated

// stat printouts
// #define RADD_EXPLORATION_PRINT_SKIPPED_NODES // prints the amount of nodes skipped after pop from the evaluation queue
// #define RADD_EXPLORATION_PRINT_EXPLORED_NODES_AMOUNT // prints the amount of nodes explored
// #define RADD_EXPLORATION_PRINT_MAX_EXPLORED_COST // prints the maximal cost popped in q
// #define RADD_EXPLORATION_PRINT_JOG_NODE_AMOUNT // prints the amount of jog nodes after extension
// #define RADD_EXPLORATION_PRINT_NODES_PER_H_VAL

// TODO: important determine if query is never achievable

// TODO: important add_query_no merge, just for comparison

// TODO: important add  and/or structure

using namespace QueryEval;

// File-local feature toggles (not yet exposed as real options).
// When set, generate_successor() skips extensions for which can_prune() holds.
static bool DO_UPWARDS_SUBQUERY_PRUNING = false; //TODO: make option
// When set, construct(InitCollection&) partitions a condition via split_init_by_vars();
// otherwise the condition is kept as a single part.
static bool ENABLE_AND_OR_NODES = true; //TODO: make option
// When set, get_next_nodes() narrows the extracted fail nodes with filter_fail_nodes().
static bool ENABLE_REGRESSION_SLIMMING = true; //TODO: make option

// Sets up the regression-add structures for `task`.
//
// The member initializers create the join order graph (jog), two table caches, two start
// node managers and the goal macro node (via create_goal()). The body then fills
//   - predicate_to_action: for every predicate, the indices of the actions adding it, and
//   - action_to_condition: for every action, its precondition as an InitCollection obtained
//     by running the precondition through a temporary JoinOrderGraph.
// Every action is expected to have exactly one add effect whose arguments are pairwise
// distinct variables (checked by asserts only).
RegrAdd::RegrAdd(LiftedStripsTask &task)
    : task(task),
    dbInfo(task.get_info()),
    jog(dbInfo),
    result_cache(jog.create_table_cache_simple(dbInfo)),  //TODO: maybe we want to reserve some space for less copies in the beginning
    start_node_manager(jog.create_recursive_start_nodes_manager_standard(dbInfo)),
    subquery_result_cache(jog.create_table_cache_simple(dbInfo)),
    subquery_node_manager(jog.create_recursive_start_nodes_manager_no_static(dbInfo)),
    goal(create_goal()) { //TODO: maybe we want to reserve some space for less copies in the beginning

    predicate_to_action.resize(task.get_predicate_amount());

    // TODO: wrap in func (create achiever)
    auto &actions = task.get_actions();
    action_to_condition.resize(actions.size());
    for (ll i = 0; i < actions.size(); i++) {
        auto &action = actions[i];

        auto pre = action.get_pre();
        auto &_add = action.get_add();
        assert(_add.size() == 1);
        auto &add = _add[0];
        //assert(action.get_del().empty());

        // The arguments of the add effect become the result parameters of the
        // precondition query. The set is only used to assert that no variable repeats.
#ifndef NDEBUG
        std::set<ParRef> res_par_set;
#endif
        std::vector<ParRef> res_pars;
        for (auto &arg : add.get_args()) {
            assert (arg.is_variable());
#ifndef NDEBUG
            res_par_set.insert(arg.get_index());
#endif
            res_pars.push_back(arg.get_index());
        }
#ifndef NDEBUG
        assert(res_par_set.size() == res_pars.size());
#endif

        //TODO: wrap below in function "normalize" --- also use in "add_new_query"
        // Build a throw-away join order graph for the precondition and read the atoms
        // back from its top node; the collected atoms are what gets stored per action.
        QueryEval::Query pre_q(pre, res_pars, dbInfo);
        auto jo = create_join_order_and_annotate(pre_q);
        JoinOrderGraph tmp_jog(pre_q, jo, dbInfo);
        auto top_node = tmp_jog.get_last_request_node();

        std::vector<ll> new_result_pars;
        InitCollection init_collection;
        tmp_jog.collect_atoms(init_collection, top_node, new_result_pars);
        assert(new_result_pars.size() == res_pars.size());
        // Debug check: the collected result parameters must be exactly 0, 1, 2, ... in order.
#ifndef NDEBUG
        ll i_t = 0;
        for (auto var : new_result_pars) {
            assert(i_t++ == var);
        }
#endif

        //TODO: important: add constant/equality constraints to pre if present in add

        predicate_to_action[add.get_predicate()].push_back(i);
        action_to_condition[i] = init_collection;
    }

    // TODO: rm code duplication below
    // Let caches and start node managers catch up with the nodes present in the jog so
    // far, then remember the node count (same sequence as at the end of get_next_nodes()).
    jog.adjust_for_new_nodes(result_cache, node_amount_on_last_extension);
    jog.adjust_for_new_nodes(start_node_manager, node_amount_on_last_extension);
    jog.adjust_for_new_nodes(subquery_node_manager, sq_node_amount_on_last_extension);
    jog.adjust_for_new_nodes(subquery_result_cache, sq_node_amount_on_last_extension);
    node_amount_on_last_extension = jog.get_node_arr_am();
    sq_node_amount_on_last_extension = jog.get_node_arr_am();
    // TODO end

#ifdef RADD_EXPLORATION_DL_TASK_DEBUG_PRINT
    std::cout << "Datalog representation:" << std::endl;
    task.dump_dl_repr(std::cout);
#endif

#ifdef RADD_EXPLORATION_PRINT_DETECTED_STATIC_PREDS_DEBUG_PRINT
    std::cout << "static predicates: ";
    for (auto pred : dbInfo.static_predicates) {
        std::cout << task.get_pred_name(pred) << ", ";
    }
    std::cout << std::endl;
#endif
}

// Returns the node of the jog's last request and records the (node -> request) pair in
// node_to_request. An already registered node keeps its existing request.
QueryEval::NodeId RegrAdd::retrieve_and_register_last_node() {
    auto last_node = jog.get_last_request_node();
    node_to_request.emplace(last_node, jog.get_last_request());
    return last_node;
}

// Adds to `res` every predicate that actually occurs in `init_col`, i.e. every key of
// predicate_init that has at least one atom (argument row) attached.
//
// The previous condition was inverted (it collected only predicates WITHOUT atoms), which
// made the debug-only "same predicates" check in construct() compare two sets that are
// empty whenever no predicate has an empty atom list, so the check could not detect a
// mismatch.
static void collect_preds(std::unordered_set<ll> &res, InitCollection &init_col) {
    for (auto &entry : init_col.predicate_init) {
        if (!entry.second.empty()) {
            res.insert(entry.first);
        }
    }
}

// Turns a condition (an InitCollection) into a macro node.
//
// With ENABLE_AND_OR_NODES the condition is first partitioned by split_init_by_vars();
// otherwise it stays a single part. Each part is added to the jog as a query without
// result parameters and without remapping; the nodes obtained that way are passed to
// construct(std::vector<NodeId>&), whose result is returned.
MarcroNodeId RegrAdd::construct(QueryEval::InitCollection &init_collection) {
    std::vector<QueryEval::InitCollection> parts;
    if (!ENABLE_AND_OR_NODES) {
        parts.push_back(init_collection);
    } else {
        split_init_by_vars(parts, init_collection);
    }

#ifdef RADD_EXPLORATION_EXTENSION_SPLIT_DEBUG_PRINT
    std::cout << "splits are: " << "(under current or node count " << micro_nodes.size() << ")" << std::endl;
    for (auto &part : parts) {
        std::cout << "split: ";
        task.dump_dl_repr(part, std::cout);
        std::cout << std::endl;
    }
#endif

    std::vector<QueryEval::NodeId> part_nodes;

    for (auto &part : parts) {
        // add_new_query wants lvalues; both containers are intentionally left empty
        std::vector<ll> empty_result_pars;
        std::unordered_map<ll,ll> empty_remap;
        jog.add_new_query(part, empty_result_pars, empty_remap, dbInfo);
#ifdef RADD_EXPLORATION_PRINT_JOG_NODE_AMOUNT
        std::cout << "JOG node amount after query add: " << jog.get_node_arr_am() << std::endl;
#endif
        auto part_node = retrieve_and_register_last_node();
#ifndef NDEBUG
        // debug check: collect_preds() must yield the same set for the part that was added
        // and for the atoms read back from the node created for it
        InitCollection node_atoms;
        std::vector<ll> unused_pars;
        jog.collect_atoms(node_atoms, QueryEval::NO_NODE, unused_pars, &part_node);

        std::unordered_set<ll> expected_preds;
        collect_preds(expected_preds, part);
        std::unordered_set<ll> actual_preds;
        collect_preds(actual_preds, node_atoms);
        assert(expected_preds == actual_preds);
#endif
        part_nodes.push_back(part_node);
    }

    return construct(part_nodes);
}

// Partitions the atoms of `init_collection` into groups such that atoms sharing a variable
// (directly or through a chain of other atoms) end up in the same group. One InitCollection
// per group is appended to `split_up_result`.
//
// Phase 1 assigns every variable a group index (map_var_to_cache) and keeps the variable
// set of every group (hacky_set_cache). Phase 2 distributes the atoms according to the
// group of their first variable.
void RegrAdd::split_init_by_vars(std::vector<QueryEval::InitCollection> &split_up_result,
                                 QueryEval::InitCollection &init_collection) {
    std::unordered_map<ParRef, size_t> map_var_to_cache;
    std::vector<std::set<ParRef>> hacky_set_cache; //TODO (important): set initial size of vec //TODO: use unordered set?

    // build var to match
    for (auto &[p, col] : init_collection.predicate_init) {
        for (auto &vars : col) {
            // indices of the groups already touched by a variable of this atom
            std::set<size_t> matches;

            for (auto var: vars) {
                if (map_var_to_cache.contains(var)) matches.insert(map_var_to_cache.at(var));
            }

            // no known variable: open a fresh group; the size()==1 branch below then
            // registers the variables in the map
            if (matches.size() == 0) {
                hacky_set_cache.emplace_back();
                for (auto var: vars) {
                    hacky_set_cache.back().insert(var);
                }
                matches.insert(hacky_set_cache.size()-1);
            }

            if (matches.size() == 1) {
                // exactly one group involved: the atom's variables join it
                size_t match = *matches.begin();
                for (auto var: vars) {
                    map_var_to_cache[var] = match;
                    hacky_set_cache[match].insert(var);
                }
            } else {
                // the atom connects several groups: merge them into a new group and
                // re-point all their variables; the old sets stay in the vector but are
                // no longer referenced by map_var_to_cache
                hacky_set_cache.emplace_back();
                size_t match = hacky_set_cache.size()-1;

                for (auto pre_v_match : matches) {
                    for (auto var : hacky_set_cache[pre_v_match]) {
                        map_var_to_cache[var] = match;
                        hacky_set_cache.back().insert(var);
                    }
                }
                for (auto var: vars) {
                    map_var_to_cache[var] = match;
                    hacky_set_cache[match].insert(var);
                }
            }
        }
    }

    // split up
    // maps a group index to the position of its InitCollection in split_up_result
    std::unordered_map<size_t, size_t> match_to_init_collection;
    for (auto &[p, col] : init_collection.predicate_init) {
        for (auto &vars: col) {
            assert(!vars.empty()); // should be guaranteed by task transformation, if not could also create different init collections

            // all variables of an atom share one group, so the first one is representative
            auto match = map_var_to_cache[*vars.begin()];
            if (!match_to_init_collection.contains(match)) {
                match_to_init_collection.emplace(match, split_up_result.size());
                split_up_result.emplace_back();
            }

            auto &init_col_match = split_up_result.at(match_to_init_collection.at(match)).predicate_init;
            if (!init_col_match.contains(p)) {
                init_col_match.emplace(p, std::vector<std::vector<ll>>());
            }

            init_col_match.at(p).push_back(vars);
        }
    }
}

// Returns the macro node that groups the given jog nodes, creating it on first use.
//
// `node_collection` is sorted in place (visible to the caller). Consecutive duplicates are
// then dropped and the de-duplicated list is used both as the lookup key in jog_to_macro
// and as the list of members. Previously the de-duplicated list was built but never used,
// so a collection containing the same node twice got a different key (and hence a
// separate macro node) than the same collection without the repetition.
//
// For a new macro node, one micro node per member is obtained via
// construct(NodeId, MarcroNodeId) and linked in both directions (groups_together on the
// macro node, reached_from on the micro nodes).
MarcroNodeId RegrAdd::construct(std::vector<QueryEval::NodeId> &node_collection) {
    std::sort(node_collection.begin(), node_collection.end());

    // sorted input: equal nodes are adjacent, so comparing with the predecessor suffices
    QueryEval::NodeId last = QueryEval::NO_NODE;
    std::vector<QueryEval::NodeId> unique_nodes;
    for (auto &node : node_collection) {
        if (last != node) unique_nodes.push_back(node);
        last = node;
        //TODO: important adjust for dropout rate; (cost x(dropout_am+1))
    }

    if (jog_to_macro.contains(unique_nodes)) {
        return jog_to_macro.at(unique_nodes);
    }

    macro_nodes.emplace_back();
    macro_node_explore_registry.notify_new_node();
    auto &node = macro_nodes.back();
    MarcroNodeId id(macro_nodes.size()-1);
    jog_to_macro.emplace(unique_nodes, id);

    for (auto node_id : unique_nodes) {
        node.groups_together.emplace(construct(node_id, id), node.groups_together.size());
    }
    for (auto &[node_id, _] : node.groups_together) {
        get(node_id)->reached_from.push_back(id);
    }

    reset(node);

    return id;
}

// Builds the macro node for the task goal.
//
// The goal is expected (asserts only) to consist of one atom with one argument. That atom
// is replaced by an atom over the same predicate whose single argument is
// ParameterOrObject(true, 0); the resulting query is added to the jog without result
// parameters and wrapped into a macro node.
MarcroNodeId RegrAdd::create_goal() {
    std::vector<Atom> goal_atoms = task.get_goal();
    assert(goal_atoms.size() == 1);
    assert(goal_atoms.front().get_args().size() == 1);
    //TODO: assert goal achievable
    goal_atoms = {Atom(goal_atoms.front().get_predicate(), {ParameterOrObject(true, 0)})};

    Parameters no_res_pars;
    QueryEval::Query goal_query(goal_atoms, no_res_pars, dbInfo);
    jog.add_new_query(goal_query, dbInfo);

    std::vector<QueryEval::NodeId> goal_nodes{retrieve_and_register_last_node()};
    return construct(goal_nodes);
}

// Returns `cost` increased by the cost of the action with index `action`.
ll RegrAdd::combine_cost(ll cost, ll action) {
    //TODO: factor in replacement multiplicator
    auto &applied_action = task.get_actions()[action];
    return cost + applied_action.get_cost();
}

#ifdef RADD_EXPLORATION_STOP_AFTER_CALL
    static ll local_call_count = 0;
#endif

// Empties `q` by assigning it a freshly value-initialized instance of its own type.
// Useful for container adaptors that offer no clear().
template <class Q>
void clear_queue(Q &q) { //TODO: move to util
    Q fresh = Q();
    q = std::move(fresh);
}

// Computes the heuristic value for `state` and returns it as the cost of the goal macro node.
//
// Outline:
//   1. Reset per-call data: result cache, start node manager, the exploration queue q and
//      every node listed in seen_nodes; then seed q and seen_nodes with the goal.
//   2. Pop entries from q until it is empty or the popped cost reaches the goal's current
//      cost. A popped node is either skipped (can_skip, followed by
//      potential_backwards_pass) or explored and marked as explored.
//   3. Return get_cost(goal).
// The many #ifdef sections only add debug output / statistics or an artificial stop.
ll RegrAdd::compute(StripsState &state) {
    result_cache.reset(); //TODO important: use delta computation instead

#ifdef RADD_EXPLORATION_STOP_AFTER_CALL
    std::cout << "==== CALLED ====" << std::endl;
#endif

#ifdef RADD_EXPLORATION_PRINT_CALLED_STATE_DEBUG_PRINT
    std::cout << "state is:" << std::endl;
    task.dump_dl_repr(state, std::cout);
#endif

    start_node_manager.reset();
    clear_queue(q);
    q.push({0, CommonNodeId(goal)});
    // only nodes touched by the previous call (recorded in seen_nodes) are reset
    for (auto node : seen_nodes) {
        if (node.is_macro()) { //TODO: make visitor
            reset(*get(node.get_node_as_macro()));
        } else {
            reset(*get(node.get_node_as_micro()));
        }
    }
    seen_nodes.clear();
    seen_nodes.emplace_back(goal);
    get(goal)->common.current_cheapest_explore_cost = 0;

    // after the reset the goal must not have a cost yet
    assert(get_cost(goal) == std::numeric_limits<ll>::max());
#ifdef RADD_EXPLORATION_STOP_AFTER
    ll iteration = 0;
#endif

#ifdef RADD_EXPLORATION_PRINT_SKIPPED_NODES
    ll skip_count = 0;
#endif
#ifdef RADD_EXPLORATION_PRINT_EXPLORED_NODES_AMOUNT
    ll pop_amount = 0;
#endif
#ifdef RADD_EXPLORATION_PRINT_MAX_EXPLORED_COST
    ll max_pop_amount = 0;
#endif

#ifdef RADD_EXPLORATION_PRINT_NODES_PER_H_VAL
    ll last_layer = 0;
#endif

    assert(q.size() == 1);
    while (!q.empty()) {
        auto [cost, expl_node] = q.top();
        q.pop();

#ifdef RADD_EXPLORATION_PRINT_EXPLORED_NODES_AMOUNT
        pop_amount++;
#endif
#ifdef RADD_EXPLORATION_PRINT_MAX_EXPLORED_COST
        max_pop_amount = std::max(max_pop_amount, cost);
#endif
#ifdef RADD_EXPLORATION_PRINT_NODES_PER_H_VAL
        if (cost > last_layer) {
            last_layer = cost;
            std::cout << "New exploration bound " << cost << ", amount of or nodes: " << micro_nodes.size()<< ", amount of and nodes: " << macro_nodes.size() << std::endl;
        }
#endif

#ifdef RADD_EXPLORATION_DEBUG_PRINT
        std::cout << "Popped cost: " << cost << " with explore cost bound of: " << get_common(expl_node).current_explore_cost_bound << std::endl;
#endif

        // stop as soon as the popped cost reaches the goal's current cost
        // NOTE(review): presumably q yields the lowest cost first — confirm against q's comparator
        if (cost >= get_cost(goal)) {
            break;
        }

        assert(valid_node(expl_node));
        // entries for nodes whose cheapest explore cost is (still/again) unset are dropped
        if (get_common(expl_node).current_cheapest_explore_cost == std::numeric_limits<ll>::max()) {
            assert(get_common(expl_node).current_explore_cost_bound < std::numeric_limits<ll>::max());
            continue;
        }
        if (expl_node.is_macro()) { //TODO: make visitor
            if (can_skip(expl_node.get_node_as_macro(), cost)) {
#ifdef RADD_EXPLORATION_PRINT_SKIPPED_NODES
                skip_count++;
#endif
                potential_backwards_pass(expl_node.get_node_as_macro());
                continue;
            }
            explore(expl_node.get_node_as_macro());
        } else {
            if (can_skip(expl_node.get_node_as_micro(), cost)) {
#ifdef RADD_EXPLORATION_PRINT_SKIPPED_NODES
                skip_count++;
#endif
                potential_backwards_pass(expl_node.get_node_as_micro());
                continue;
            }
            explore(expl_node.get_node_as_micro(), state);
        }
        // only reached when the node was actually explored (skips `continue` above)
        get_common(expl_node)._was_explored = true;

#ifdef RADD_EXPLORATION_PRINT_GOAL_COST_AFTER_EXPLORE
        std::cout << "Goal cost is: " << get_cost(goal) << std::endl;
#endif
#ifdef RADD_EXPLORATION_STOP_AFTER
        if (iteration++ == RADD_EXPLORATION_STOP_AFTER) {
            break;
        }
#endif
    }

#ifdef RADD_EXPLORATION_PRINT_MICRO_MACRO_GRAPH
    // dump of all and (macro) / or (micro) nodes and their links
    std::cout << std::endl << std::endl;
    std::cout << "========" << std::endl;
    std::cout << "Micro/Macro graph:" << std::endl;
    std::cout << "========" << std::endl;
    ll i = 0;
    for (auto &node : macro_nodes) {
        std::cout << "And node" << i++ << ": " << std::endl;
        std::cout << "groups together:" << std::endl;
        for (auto &[node, index] : node.groups_together) {
            std::cout << node.get() << ", ";
        }
        std::cout << std::endl << std::endl;
    }
    i = 0;
    for (auto &node : micro_nodes) {
        std::cout << "Or node" << i++<< ": ";
        std::cout << "representing ";
        InitCollection atoms_underlying;
        std::vector<ll> ignore;
        jog.collect_atoms(atoms_underlying, QueryEval::NO_NODE, ignore, &node.assoiciated_query);
        task.dump_dl_repr(atoms_underlying, std::cout);
        std::cout << " cost " << node.common.fulfilled_cost << std::endl;
        std::cout << "is extended to:" << std::endl;
        for (auto &[_, v1] : node.substitutes_to) {
            for (auto &v2 : v1) {
                for (auto &s_link : v2) {
                    std::cout << s_link << ", ";
                }
            }
        }
        std::cout << std::endl << std::endl;
    }
#endif

#ifdef RADD_EXPLORATION_PRINT_SKIPPED_NODES
    std::cout << "Amount of nodes skipped after pop from eval-q: " << skip_count << std::endl;
#endif

#ifdef RADD_EXPLORATION_PRINT_EXPLORED_NODES_AMOUNT
    std::cout << "Amount of nodes popped in regr-add eval: " << pop_amount << std::endl;
#endif

#ifdef RADD_EXPLORATION_PRINT_MAX_EXPLORED_COST
    std::cout << "Max cost popped during regr-add eval: " << max_pop_amount << std::endl;
#endif


#ifdef RADD_EXPLORATION_PRINT_RESULT
    std::cout << "regr add result is: " << get_cost(goal) << std::endl;
#endif

#ifdef RADD_EXPLORATION_STOP_AFTER_CALL
    // debugging aid: terminate the whole process after a fixed number of calls
    if (++local_call_count == RADD_EXPLORATION_STOP_AFTER_CALL) {
        std::cout << "ending" << std::endl;
        exit(0);
    }
#endif

    return get_cost(goal);
}

// Queues the micro node `micro` for exploration, using the current cheapest explore cost
// of the macro node `macro` as queue cost, and records `micro` in seen_nodes.
void RegrAdd::register_q_exploral(MarcroNodeId macro, MircroNodeId micro) { //TODO: combine with function below
    auto *source = get(macro);
    // lookup kept although its result is not needed outside of debug output
    [[maybe_unused]] auto *target = get(micro);

    auto queue_cost = source->common.current_cheapest_explore_cost;
    q.push({queue_cost, micro});
    seen_nodes.emplace_back(micro);
#ifdef RADD_EXPLORATION_BOUND_CHANGE
    std::cout << "Explore cost bound for micro node " << micro.get() << " is: " << get(micro)->common.current_explore_cost_bound << " (updated by q exploral)" << std::endl;
#endif
}

// Queues the macro node `macro` for exploration, using the current cheapest explore cost
// of the micro node `micro` as queue cost, and records `macro` in seen_nodes.
// `act_id` is accepted but not used here.
void RegrAdd::register_q_exploral(MircroNodeId micro, MarcroNodeId macro, ll act_id) {
    auto *source = get(micro);
    // lookup kept although its result is not needed outside of debug output
    [[maybe_unused]] auto *target = get(macro);

    auto queue_cost = source->common.current_cheapest_explore_cost;
    q.push({queue_cost, macro});
    seen_nodes.emplace_back(macro);
#ifdef RADD_EXPLORATION_BOUND_CHANGE
    std::cout << "Explore cost bound for macro node " << macro.get() << " is: " << get(macro)->common.current_explore_cost_bound << " (updated by q exploral)" << std::endl;
#endif
}

// Explores an and (macro) node: every micro node it groups is told that it was reached
// from this node and is either queued (when try_set_cheaper_cost succeeds) or handed to
// potential_backwards_pass().
void RegrAdd::explore(MarcroNodeId node_id) {
#ifdef RADD_EXPLORATION_DEBUG_PRINT
    std::cout << "Exploring and node: " << node_id.get() <<  std::endl;
#endif
    auto *and_node = get(node_id);

    //TODO: seems hacky, find proper place to do this
    // inherit the explore cost bound from the node this one was reached from most cheaply
    if (and_node->current_cheapest_reached != NO_MICRO_NODE) {
        auto *cheapest_origin = get(and_node->current_cheapest_reached);
        and_node->common.current_explore_cost_bound = cheapest_origin->common.current_explore_cost_bound;
    }

    for (auto &[child_id, group_index] : and_node->groups_together) {
        auto *child = get(child_id);
        child->reached_this_iteration.insert(node_id);
        child->reached_this_iteration_todo.insert(node_id);

        if (!try_set_cheaper_cost(node_id, child_id)) {
            potential_backwards_pass(child_id);
            continue;
        }
        register_q_exploral(node_id, child_id);
    }
}

// Explores an or (micro) node: if its associated query is fulfilled in `state`, a backwards
// pass with cost 0 is started from it; otherwise its successors are handled by
// get_next_nodes().
void RegrAdd::explore(MircroNodeId node_id, StripsState &state) {
#ifdef RADD_EXPLORATION_DEBUG_PRINT
    std::cout << "Exploring or node: " << node_id.get() <<  std::endl;
#endif

    auto *or_node = get(node_id);

    //TODO: seems hacky, find proper place to do this
    // inherit the explore cost bound from the node this one was reached from most cheaply
    auto *cheapest_origin = get(or_node->current_cheapest_reached);
    or_node->common.current_explore_cost_bound = cheapest_origin->common.current_explore_cost_bound;

    if (!fulfilled(or_node->assoiciated_query, state)) {
        get_next_nodes(node_id);
        return;
    }

#ifdef RADD_EXPLORATION_FULFILLED_DEBUG_PRINT
    std::cout << "==> Fulfilled. (or node: " << node_id.get() << ")" << std::endl;
#endif
    backwards_pass(node_id, 0);
}

// Debug helper: prints the jog node `node_id` and the query it represents; depending on
// further flags also the join order and the table cache. Without
// RADD_EXPLORATION_DEBUG_PRINT the function body is empty. `state` is not used.
void RegrAdd::node_state_eval_print(QueryEval::NodeId node_id, StripsState &state) {
#ifdef RADD_EXPLORATION_DEBUG_PRINT //TODO wrap in func
    std::cout << "Exploring node ";
    task.dump_dl_repr(node_id, std::cout);
    std::cout << std::endl;
    std::cout << "Representing query ";
    //TODO: wrap below in function
    // read the atoms underlying the node back from the jog for printing
    InitCollection atoms_underlying;
    std::vector<ll> ignore;
    jog.collect_atoms(atoms_underlying, QueryEval::NO_NODE, ignore, &node_id);
    task.dump_dl_repr(atoms_underlying, std::cout);
#ifdef RADD_EXPLORATION_DEBUG_PRINT_ORDER_PRINT
    std::cout << std::endl << "order: ";
    task.dump_dl_repr(jog.get_nodes(), std::cout);
#endif
#ifdef RADD_EXPLORATION_DEBUG_PRINT_CACHE_PRINT
    std::cout << std::endl << "Table cache is: " << std::endl;
    task.query_eval_print(jog, result_cache, nullptr, std::cout, true);
#endif
    std::cout << std::endl << std::endl;
#endif
}

// Marks `node` for exploration, evaluates the jog on `state` and returns true iff the
// result stored for the request registered for `node` is non-empty (its `.second` part).
bool RegrAdd::fulfilled(QueryEval::NodeId node, StripsState &state) {
    jog.mark_for_exploration(node, start_node_manager);
    jog.evaluate(state, result_cache, start_node_manager, dbInfo); // TODO: verbose
    node_state_eval_print(node, state);

    //TODO: wrap in function
    auto &request = node_to_request[node];
    const bool has_result = !jog.get_result(request, result_cache).second.empty();

#ifdef RADD_EXPLORATION_FULFILLED_SANITY_CHECKS
    // a query consisting of static predicates only is expected to be fulfilled
    InitCollection node_atoms;
    std::vector<ll> unused_pars;
    jog.collect_atoms(node_atoms, QueryEval::NO_NODE, unused_pars, &node);
    bool only_static = true;
    for (auto &[p, _] : node_atoms.predicate_init) {
        if (!dbInfo.static_predicates.contains(p)) {
            only_static = false;
            break;
        }
    }
    if (only_static) assert(has_result);
#endif
    return has_result;
}

// Returns the largest variable index occurring in any atom of `init_collection`,
// or 0 if there is none (the running maximum starts at 0).
//
// Rows are visited by const reference; the previous version copied every argument
// vector once per call.
ll get_max_var(InitCollection &init_collection) {
    ll largest = 0;
    for (const auto &entry : init_collection.predicate_init) {
        for (const auto &row : entry.second) {
            for (const auto var : row) {
                largest = std::max(largest, var);
            }
        }
    }

    return largest;
}

// NOTE(review): not referenced anywhere in the visible part of this file, and (unlike the
// toggles at the top of the file) not declared static — confirm whether it is used elsewhere.
bool use_all_next_nodes = false; //TODO: make an actual option

// Removes duplicate atoms of static predicates from `init_collection`.
//
// For every predicate listed in db_info.static_predicates, the rows (argument lists) are
// sorted and repeated rows are dropped. Rows of other predicates are left untouched.
//
// Uses std::unique instead of the former hand-written loop, which compared against a
// sentinel row {-1} (and would therefore have dropped a genuine first row equal to {-1})
// and copied every row as well as the whole table.
static void reduce_init_collection(InitCollection &init_collection, DBInfo &db_info) {
    for (auto &[pred, table] : init_collection.predicate_init) {
        if (!db_info.static_predicates.contains(pred)) {
            continue;
        }
        std::sort(table.begin(), table.end());
        table.erase(std::unique(table.begin(), table.end()), table.end());
    }
}

// Returns the macro nodes obtained by replacing the i-th atom of `predicate` in the query
// of micro node `node_id` with the condition of each action that adds `predicate`.
//
// The result is computed once per (predicate, i) and then served from substitutes_to
// (guarded by substitution_computed). For every created macro node the micro node also
// remembers, in action_annoation, the cheapest action leading to it.
// The returned reference points into the micro node's own storage.
std::vector<MarcroNodeId> &RegrAdd::generate_successor(MircroNodeId node_id, ll predicate, ll i) {
    if (get(node_id)->substitution_computed.at(predicate).at(i)) {
        return get(node_id)->substitutes_to.at(predicate).at(i);
    }

#ifdef RADD_EXPLORATION_OR_NODE_EXTENSION_DEBUG_PRINT
    std::cout << "Extending Or node " << node_id.get() << std::endl;
#endif

    for (auto act: predicate_to_action[predicate]) {
        // copy of the action's condition (the stored one stays untouched)
        auto achiever = action_to_condition[act];

        // should be here to avoid segfaults after node reallocation
        // (construct() further down can add nodes, so these references are re-fetched
        // in every iteration instead of once before the loop)
        auto &atoms_underlying = get(node_id)->query_representation;
        ll base_max = get(node_id)->max_var;
        auto &vec = atoms_underlying.predicate_init.at(predicate);
        auto &args = vec[i];

        std::map<ll, ll> var_map; //TODO (important): should be vector
        create_var_map(args, achiever, base_max, var_map);

#ifdef RADD_EXPLORATION_OR_NODE_EXTENSION_VAR_MAP_DEBUG_PRINT
        /*std::cout << "query representation before: ";
        task.dump_dl_repr(atoms_underlying, std::cout);
        std::cout << std::endl;*/

        std::cout << "achiever: ";
        task.dump_dl_repr(achiever, std::cout);
        std::cout << std::endl;

        std::cout << "map is: ";
        for (auto &[v1, v2]: var_map) {
            std::cout << v1 << " -> " << v2 << ", ";
        }
        std::cout << std::endl;
#endif

        // build the extended condition and drop duplicate static atoms from it
        InitCollection new_condition;
        extend_condition(atoms_underlying, achiever, new_condition, predicate, i, var_map);
        reduce_init_collection(new_condition, dbInfo);

#ifdef RADD_EXPLORATION_OR_NODE_EXTENSION_DEBUG_PRINT
        std::cout << "Extension is: ";
        task.dump_dl_repr(new_condition, std::cout);
        std::cout << std::endl;
#endif
        if (DO_UPWARDS_SUBQUERY_PRUNING && can_prune(new_condition, node_id)) { //TODO: important: maybe just go upward up to range?
            continue;
        }

        // from here on only get(node_id) is used; atoms_underlying/vec/args are not touched again
        MarcroNodeId macro_node = construct(new_condition);
        get(node_id)->substitutes_to.at(predicate).at(i).push_back(macro_node);
        get(macro_node)->reached_from.emplace_back(node_id, act); //TODO (important): could lead to duplicates

        auto &annotation_map = get(node_id)->action_annoation;

        // keep one action per macro node: the first one seen, replaced only by a
        // strictly cheaper one
        bool should_add_annotation = false;
        if (!annotation_map.contains(macro_node)) {
            should_add_annotation = true;
        } else if (task.get_actions()[act].get_cost() < task.get_actions()[annotation_map.at(macro_node)].get_cost()) { //TODO: wrap cost retrival into func get action cost
            should_add_annotation = true;
            annotation_map.erase(macro_node);
        }
        if (should_add_annotation) {
            annotation_map.emplace(macro_node, act);
        }
    }

    get(node_id)->substitution_computed[predicate][i] = true;
    return get(node_id)->substitutes_to.at(predicate).at(i);
}

// Generates (or fetches the cached) successors for the i-th atom of `predicate` in micro
// node `node_id` and appends them to that node's substitued_to_this_iteration list.
// Despite the name, nothing is pushed to the exploration queue here.
void RegrAdd::add_successor_to_q(MircroNodeId node_id, ll predicate, ll i) { // TODO: rn: no longer adds to queue but only generates
    assert(node_id != NO_MICRO_NODE);
    auto &successors = generate_successor(node_id, predicate, i);

    // fetched after generate_successor(), which may have added nodes
    auto *or_node = get(node_id);
    for (const auto &successor : successors) {
        or_node->substitued_to_this_iteration.push_back(successor);
    }

#ifdef RADD_EXPLORATION_CHOSEN_ATOM_SUBST_DEBUG_PRINT
    std::cout << "In Or node" << node_id.get() << ": ";
    std::cout << "representing ";
    InitCollection atoms_underlying;
    std::vector<ll> ignore;
    jog.collect_atoms(atoms_underlying, QueryEval::NO_NODE, ignore, &get(node_id)->assoiciated_query);
    task.dump_dl_repr(atoms_underlying, std::cout);
    std::cout << std::endl;
    std::cout << "the following atom is replaced for further exploration: ";
    std::cout << task.get_pred_name(predicate) << "(";
    auto s = get(node_id)->query_representation.predicate_init.at(predicate).at(i).size();
    for (auto var : get(node_id)->query_representation.predicate_init.at(predicate).at(i)) {
        std::cout << "?" << var << (--s != 0 ? ", " : "");
    }
    std::cout << ")" << std::endl;
    std::cout << "Amount of generated successors: " << successors.size() << std::endl;
    ll successor_count = 0;
#endif
}

// Selects a subset of the failing nodes by annotation.
//
// Every failing node carries a list of annotations. For each annotation value a weight is
// accumulated: the sum, over the nodes carrying it, of the number of entries the node has
// in `jog_query_repr_mapping`. The first annotation with the smallest weight is chosen and
// every node carrying it is appended to `filtered_fail_nodes` (once per occurrence of the
// annotation in the node's list). Annotation -1 never contributes a weight.
//
// NOTE(review): annotation values in [0, max] that no node carries keep weight 0 and can
// therefore be chosen, which yields an empty result — confirm annotations are dense.
void filter_fail_nodes(std::unordered_map<QueryEval::NodeId, std::vector<std::pair<ll, ll>>> &jog_query_repr_mapping,
                       std::unordered_map<QueryEval::NodeId, std::vector<ll>> &failing_nodes,
                       std::vector<QueryEval::NodeId> &filtered_fail_nodes) {
    // largest annotation value present (-1 if there is none)
    ll highest_annotation = -1;
    for (auto &entry : failing_nodes) {
        for (auto annotation : entry.second) {
            if (highest_annotation < annotation) {
                highest_annotation = annotation;
            }
        }
    }

    // accumulated weight per annotation value
    std::vector<ll> weight_per_annotation(highest_annotation+1, 0);
    for (auto &entry : failing_nodes) {
        for (auto annotation : entry.second) {
            if (annotation == -1) {
                continue;
            }
            weight_per_annotation[annotation] += jog_query_repr_mapping.at(entry.first).size();
        }
    }

    // first annotation with the strictly smallest weight
    ll lowest_weight = std::numeric_limits<ll>::max();
    ll chosen_annotation = -1;
    ll candidate = 0;
    for (auto weight : weight_per_annotation) {
        if (weight < lowest_weight) {
            lowest_weight = weight;
            chosen_annotation = candidate;
        }
        ++candidate;
    }

    if (chosen_annotation == -1) {
        return;
    }

    for (auto &entry : failing_nodes) {
        for (auto annotation : entry.second) {
            if (annotation == chosen_annotation) {
                filtered_fail_nodes.push_back(entry.first);
            }
        }
    }
}

// Handles the successors of an or (micro) node whose query is not fulfilled.
//
// First call for a node (substitued_to_this_iteration_created is false): the failing jog
// nodes are extracted from the result cache, optionally narrowed by filter_fail_nodes(),
// and for every atom mapped to one of them the successors are generated via
// add_successor_to_q(). Afterwards caches and start node managers are adjusted to the
// nodes that were added to the jog.
//
// Every call: each macro node in substitued_to_this_iteration is told that it was reached
// from this node with the annotated action and is either queued (not yet explored and
// try_set_cheaper_cost succeeds) or handed to potential_backwards_pass().
void RegrAdd::get_next_nodes(MircroNodeId node_id) {
    auto *node = get(node_id);

    if (!node->substitued_to_this_iteration_created) {
        std::unordered_map<QueryEval::NodeId, std::vector<ll>> failing_nodes;
        jog.extract_fail_nodes(result_cache, node->assoiciated_query, failing_nodes);

#ifdef RADD_EXPLORATION_DEBUG_PRINT_EXTRACTED_FAIL_NODES
        std::cout << "Extracted Fail Predicates are: ";
        for (auto &[node, v] : failing_nodes) {
            assert(node.is_init());
            std::cout
                << task.get_pred_name(
                       static_cast<InitNode &>(jog.pub_get(node)).initializer.get()->predicate)
                << (jog.pub_get(node).is_static ? " (static" : " (non-static")
                << ", id: " << std::to_string(node.id_num) << "), ";
        }
        std::cout << std::endl;
#endif

        std::vector<QueryEval::NodeId> filtered_fail_nodes;
        if (ENABLE_REGRESSION_SLIMMING) {
            filter_fail_nodes(node->jog_query_repr_mapping, failing_nodes, filtered_fail_nodes);
        } else {
            for (auto &[n, _] : failing_nodes) {
                filtered_fail_nodes.push_back(n);
            }
        }

#ifdef RADD_EXPLORATION_DEBUG_PRINT_EXTRACTED_FAIL_NODES
        std::cout << "Filtered Fail Predicates are: ";
        for (auto node : filtered_fail_nodes) {
            assert(node.is_init());
            std::cout
                << task.get_pred_name(
                       static_cast<InitNode &>(jog.pub_get(node)).initializer.get()->predicate)
                << ", ";
        }
        std::cout << std::endl;
#endif

#ifdef RADD_EXPLORATION_DEBUG_PRINT_SUCCESSOR_AMOUNT
        // NOTE(review): p_i_subst_combinations is not declared in the visible code
        std::cout << "Amount of atoms to replace: " << p_i_subst_combinations.size() << std::endl;
        ll created_am = 0;
#endif

        for (auto _node_id : filtered_fail_nodes) {
            // Iterate over a copy: add_successor_to_q() can create new micro nodes (via
            // generate_successor() -> construct()), so a range-for bound directly to a
            // vector stored inside a micro node could end up walking relocated storage.
            const auto atoms_to_replace = get(node_id)->jog_query_repr_mapping[_node_id];
            for (const auto &[predicate, i] : atoms_to_replace) {
                add_successor_to_q(node_id, predicate, i);
#ifdef RADD_EXPLORATION_DEBUG_PRINT_SUCCESSOR_AMOUNT
                created_am++;
#endif
            }
        }

#ifdef RADD_EXPLORATION_DEBUG_PRINT_SUCCESSOR_AMOUNT
        std::cout << "Amount of new queries created: " << created_am << std::endl;
#endif

        // let caches and managers catch up with the jog nodes added above
        jog.adjust_for_new_nodes(result_cache, node_amount_on_last_extension);
        jog.adjust_for_new_nodes(start_node_manager, node_amount_on_last_extension);
        jog.adjust_for_new_nodes(subquery_node_manager, sq_node_amount_on_last_extension);
        jog.adjust_for_new_nodes(subquery_result_cache, sq_node_amount_on_last_extension);
        node_amount_on_last_extension = jog.get_node_arr_am();
        sq_node_amount_on_last_extension = jog.get_node_arr_am();
        // `node` may be stale after the successor generation, hence the fresh lookup
        get(node_id)->substitued_to_this_iteration_created = true;
    }

    for (auto macro_id : get(node_id)->substitued_to_this_iteration) {
        auto *node = get(macro_id);
        ll act = get(node_id)->action_annoation.at(macro_id);
        node->reached_this_iteration.emplace(node_id, act);
        node->reached_this_iteration_todo.emplace(node_id, act);
#ifdef RADD_EXPLORATION_CHOSEN_ATOM_SUBST_DEBUG_PRINT
        // NOTE(review): successor_count and subts_link are not declared in this function
        std::cout << "successor " << ++successor_count << " is and node " << subts_link.to.get() << ", grouping or nodes: ";
        ll x = 0;
        for (auto &[node,_] : get(subts_link.to)->groups_together) {
            std::cout << node.get();
            if (++x != get(subts_link.to)->groups_together.size()) std::cout << ", ";
        }
        std::cout << std::endl;
#endif
        if (!was_explored(node->common) && try_set_cheaper_cost(node_id, act, macro_id)) {
            register_q_exploral(node_id, macro_id, act);
        } else {
            potential_backwards_pass(macro_id);
        }
    }

#ifdef RADD_EXPLORATION_CHOSEN_ATOM_SUBST_DEBUG_PRINT
    std::cout << "===== end successor list ========" << std::endl;
#endif
}

// Returns the micro node registered for the JOG node `node_id`, creating it on first request.
//
// `jog_to_micro` is indexed by `jog.arr_get(node_id)` and holds NO_MICRO_NODE for JOG nodes
// without a micro node. On a miss a new micro node is appended to `micro_nodes`, reset,
// linked to `parent` and `node_id`, and its query representation is collected from the JOG.
// For an already registered JOG node the stored id is returned and `parent` is not used.
MircroNodeId RegrAdd::construct(QueryEval::NodeId node_id, MarcroNodeId parent) {
    const auto jog_slot = jog.arr_get(node_id);

    // Grow the lookup table (at least doubling it) so that `jog_slot` becomes a valid index.
    if (jog_slot >= jog_to_micro.size()) {
        const size_t minimum_size = static_cast<size_t>(jog_slot+1);
        const size_t doubled_size = 2*(jog_to_micro.size());
        jog_to_micro.resize(std::max(minimum_size, doubled_size), NO_MICRO_NODE);
    }
    assert(jog_to_micro.size() > jog_slot);

    auto &registered_id = jog_to_micro[jog_slot];
    if (registered_id != NO_MICRO_NODE) {
        // A micro node was already built for this JOG node.
        return registered_id;
    }

    // Append a fresh micro node and tell the exploration registry about it.
    micro_nodes.emplace_back();
    micro_node_explore_registry.notify_new_node();
    auto &created = micro_nodes.back();
    reset(created);
    created.parent = parent;
    created.assoiciated_query = node_id;

    // Collect the atoms of the represented query together with the JOG-node mapping.
    std::vector<ll> ignore; //TODO: hacky, make opt in function
    jog.collect_atoms(created.query_representation,
                      QueryEval::NO_NODE,
                      ignore,
                      &node_id,
                      nullptr,
                      nullptr,
                      &created.jog_query_repr_mapping);

    // One (initially empty / not-yet-computed) substitution slot per atom of every predicate.
    for (const auto &[predicate, atoms] : created.query_representation.predicate_init) {
        const auto atom_count = atoms.size();
        created.substitutes_to.emplace(predicate, std::vector<std::vector<MarcroNodeId>>(atom_count));
        created.substitution_computed.emplace(predicate, std::vector<bool>(atom_count, false));
    }

    created.max_var = get_max_var(created.query_representation);

    // The new node sits at the end of `micro_nodes`; remember its id for later lookups.
    registered_id = MircroNodeId(micro_nodes.size()-1);
    return registered_id;
}

// Builds `result` from `original` by removing one atom and appending the atoms of `extension`.
//
// original               condition to start from (copied into `result`)
// extension              atoms to add; their variables are renamed through `extension_var_map`
// result                 output; its `predicate_init` is overwritten
// p_to_replace           predicate whose atom list loses one entry
// p_args_pos_to_replace  index of the removed atom inside that predicate's atom list
// extension_var_map      variable renaming applied to every argument of `extension`;
//                        every argument is expected to have an entry (asserted)
//
// A predicate whose atom list ends up empty after both steps is erased from `result`.
void RegrAdd::extend_condition(QueryEval::InitCollection &original, QueryEval::InitCollection &extension,
                               QueryEval::InitCollection &result, ll p_to_replace, ll p_args_pos_to_replace,
                               std::map<ll, ll> &extension_var_map) {
    result.predicate_init = original.predicate_init;

    // Rebuild the atom list of the targeted predicate without the replaced atom.
    const auto source_atoms = original.predicate_init.at(p_to_replace);
    std::vector<std::vector<ll>> kept_atoms;
    ll atom_pos = 0;
    for (const auto &atom_args : source_atoms) {
        if (atom_pos != p_args_pos_to_replace) {
            kept_atoms.push_back(atom_args);
        }
        ++atom_pos;
    }
    result.predicate_init.at(p_to_replace) = kept_atoms;

    // Append every extension atom with its variables renamed.
    for (auto &[predicate, atoms] : extension.predicate_init) {
        if (!result.predicate_init.contains(predicate)) {
            result.predicate_init.emplace(predicate, std::vector<std::vector<ll>>());
        }
        auto &target_atoms = result.predicate_init.at(predicate);

        for (auto &atom_args : atoms) {
            std::vector<ll> renamed_args;
            for (auto variable : atom_args) {
                assert(extension_var_map.contains(variable));
                renamed_args.push_back(extension_var_map[variable]);
            }
            target_atoms.push_back(renamed_args);
        }
    }

    // The extension may have re-populated the predicate, hence the check only happens now.
    if (result.predicate_init.at(p_to_replace).empty()) {
        result.predicate_init.erase(p_to_replace);
    }
}

void RegrAdd::create_var_map(std::vector<ll> &new_add_args, QueryEval::InitCollection &achiever, ll base_max,
                             std::map<ll, ll> &result) {
    for (ll i = 0; i < new_add_args.size(); i++) {
        result.emplace(i, new_add_args[i]);
    }
    ll achiever_max = get_max_var(achiever);
    base_max++;
    for (ll i = new_add_args.size(); i <= achiever_max; i++) {
        result.emplace(i, base_max++);
    }
}

// Tries to lower the exploration cost of macro node `to_id` by reaching it from micro node
// `from_id` via action `action_id`.
//
// The candidate cost is the source's cheapest exploration cost plus the action cost. The
// candidate is always reported to the macro exploration registry. Returns true iff the
// candidate is strictly cheaper than the target's current cost; in that case the target
// takes over the cost, remembers `from_id` as its cheapest predecessor and inherits the
// source's exploration cost bound.
bool RegrAdd::try_set_cheaper_cost(MircroNodeId from_id, ll action_id, MarcroNodeId to_id) {
    auto *source = get(from_id);
    auto *target = get(to_id);

    const auto candidate_cost =
        source->common.current_cheapest_explore_cost + task.get_actions()[action_id].get_cost(); //TODO: amount of subst

    macro_node_explore_registry.mark_potential_later_explore(from_id, to_id, candidate_cost);

    const bool no_improvement = candidate_cost >= target->common.current_cheapest_explore_cost;
    if (no_improvement) {
        return false;
    }

    auto &target_common = target->common;
    target_common.current_cheapest_explore_cost = candidate_cost;
    target->current_cheapest_reached = from_id;
    target_common.current_explore_cost_bound = source->common.current_explore_cost_bound;
    return true;
}

// Puts the bookkeeping shared by micro and macro nodes back into its initial state:
// not explored, exploration allowed, no cost explored so far, and both the cost bound
// and the cheapest exploration cost set to the largest representable value.
void RegrAdd::reset(CommonNodeContent &node) {
    constexpr ll unreached_cost = std::numeric_limits<ll>::max(); //TODO: should depend on cost type

    node._was_explored = false;
    node.do_not_explore = false;
    node.already_further_explored_cost = 0;
    node.current_cheapest_explore_cost = unreached_cost;
    node.current_explore_cost_bound = unreached_cost;
}

// Resets a macro node: shared bookkeeping, one "no cost yet" marker (-1) per grouped
// micro node, an accumulated fulfilled cost of 0, its registry queue, and all
// information about which micro nodes reached it.
void RegrAdd::reset(MacroNode &node) {
    reset(node.common);

    const auto group_count = node.groups_together.size();
    node.common.fulfilled_cost = 0;
    node.unfulfilled_count = group_count;
    node.fulfilled_cost = std::vector<ll>(group_count, -1); //TODO: should depend on cost type

    macro_node_explore_registry.reset_q_manager(get_macro_id(node));

    node.reached_this_iteration_todo.clear();
    node.reached_this_iteration.clear();
    node.current_cheapest_reached = NO_MICRO_NODE;
}

// Resets a micro node: shared bookkeeping, a fulfilled cost of "largest representable
// value", its registry queue, the reached-from information and the per-iteration
// substitution state.
void RegrAdd::reset(MicroNode &node) {
    reset(node.common);
    node.common.fulfilled_cost = std::numeric_limits<ll>::max();

    micro_node_explore_registry.reset_q_manager(get_micro_id(node));

    // Forget who reached this node.
    node.reached_this_iteration_todo.clear();
    node.reached_this_iteration.clear();
    node.current_cheapest_reached = NO_MACRO_NODE;

    // Forget the substitutions created in the current iteration.
    node.substitued_to_this_iteration.clear();
    node.substitued_to_this_iteration_created = false;
}

// Propagates a fulfilment cost into macro node `node_id` and, once every micro node
// grouped by it has a recorded cost, forwards the accumulated cost to the nodes that
// reached this macro node.
//
// node_id  macro node to update
// cost     cost with which micro node `from` was fulfilled (ignored if `from` is NO_MICRO_NODE)
// from     micro node that was fulfilled; NO_MICRO_NODE skips the cost update and only
//          (re-)runs the forwarding step below
//
// node->fulfilled_cost holds one entry per grouped micro node, -1 meaning "no cost yet";
// node->common.fulfilled_cost is kept equal to the sum of the recorded entries.
void RegrAdd::backwards_pass(MarcroNodeId node_id, ll cost, MircroNodeId from)
{
    auto *node = get(node_id);

    if (from != NO_MICRO_NODE) {
#ifdef RADD_EXPLORATION_BACKWARDS_PASS_PRINT
    std::cout << "Macro node " << node_id.get() << " reached with cost " << cost
              << ", current unfulfullied count: " << node->unfulfilled_count << std::endl;
#endif

        // Position of `from` inside fulfilled_cost.
        // NOTE(review): if groups_together is a std::map-like container, operator[] inserts a
        // default entry for an unknown `from` — confirm callers only pass grouped micro nodes.
        auto f_id = node->groups_together[from];
        assert(f_id < node->fulfilled_cost.size());
        auto &current_cost = node->fulfilled_cost[f_id];
        if (current_cost == -1) {
            // First cost for this group member: one fewer member is missing.
            node->unfulfilled_count--;
            node->common.fulfilled_cost += cost;

#ifdef RADD_EXPLORATION_BACKWARDS_PASS_PRINT
        std::cout << "Unfulfilled count was updated to " << node->unfulfilled_count
                  << " in macro node " << node_id.get() << std::endl;
#endif
        }
        else if (current_cost <= cost) {
            // Already recorded with an equal or cheaper cost: nothing changes, nothing to forward.
            return;
        }
        else {
            // Cheaper cost for an already recorded member: lower the sum by the difference.
            node->common.fulfilled_cost -= (current_cost - cost);
            if (!node->unfulfilled_count) {
                // The node was already complete, so everything that reached it has to be
                // notified again about the improved sum.
                node->reached_this_iteration_todo = node->reached_this_iteration;
            }
        }

        current_cost = cost;
        assert(current_cost != -1);
        assert(node->unfulfilled_count <= node->groups_together.size());
    }

#ifndef NDEBUG
    // Debug-only consistency check: unfulfilled_count must match the number of -1 entries.
    ll minuscount = 0;
    for (ll i : node->fulfilled_cost) {
        if (i == -1) {
            minuscount++;
        }
    }
    assert(minuscount == node->unfulfilled_count);
#endif

    if (!node->unfulfilled_count) {
        assert(!node->reached_this_iteration.empty() || node_id == goal);
        // Forward (accumulated cost + cost of the reaching action) to every pending reaching node.
        // NOTE(review): the recursive call may presumably come back to this macro node (if the
        // micro/macro graph contains cycles) and reassign reached_this_iteration_todo while it
        // is being iterated here — confirm this cannot happen.
        for (auto &[reaching_node_id, reaching_action] : node->reached_this_iteration_todo) {
            backwards_pass(reaching_node_id, node->common.fulfilled_cost + task.get_actions()[reaching_action].get_cost());
        }
        node->reached_this_iteration_todo.clear();
    }
}

// Records `cost` as the fulfilled cost of micro node `node_id` (unless a cheaper one is
// already stored) and forwards it to the pending nodes in reached_this_iteration_todo.
//
// node_id  micro node that was fulfilled
// cost     cost with which it was fulfilled
//
// A fulfilled cost of std::numeric_limits<ll>::max() is the value a micro node has after
// reset(), i.e. no cost has been recorded yet.
void RegrAdd::backwards_pass(MircroNodeId node_id, ll cost) {
    auto *node = get(node_id);

#ifdef RADD_EXPLORATION_BACKWARDS_PASS_PRINT
    std::cout << "Micro node " << node_id.get()  << " reached with cost " << cost << std::endl;
#endif

    // A strictly cheaper cost is already recorded: ignore the new one.
    if (node->common.fulfilled_cost < cost) {
        return;
    }
    // Same cost as before and nobody left to notify: nothing to do.
    if (node->common.fulfilled_cost == cost && node->reached_this_iteration_todo.empty()) {
        return;
    }

    // The node already had a (more expensive) cost, so everything that reached it has to be
    // notified again, not only the still pending entries.
    if (node->common.fulfilled_cost != std::numeric_limits<ll>::max() && node->common.fulfilled_cost != cost) {
        node->reached_this_iteration_todo = node->reached_this_iteration;
    }

    node->common.fulfilled_cost = cost;

    assert(!node->reached_this_iteration.empty());
    // Hand the cost to every pending reaching node, naming this micro node as the origin.
    // NOTE(review): as in the macro overload, the recursion could presumably re-enter this
    // node and modify reached_this_iteration_todo during iteration — confirm.
    for (auto reaching_node_id : node->reached_this_iteration_todo) {
        backwards_pass(reaching_node_id, cost, node_id);
    }
    node->reached_this_iteration_todo.clear();
    // Report (cheapest exploration cost + fulfilled cost) for this node.
    mark_minimal_achiever_cost(node_id, node->common.current_cheapest_explore_cost + node->common.fulfilled_cost);
}

// Tries to lower the exploration cost of micro node `to_id` by reaching it from macro node
// `from_id`. No cost is added on this step: the candidate is the source's own cost.
//
// The candidate is always reported to the micro exploration registry. Returns true iff it
// is strictly cheaper than the target's current cost; in that case the target takes over
// the cost, remembers `from_id` as its cheapest predecessor and inherits the source's
// exploration cost bound.
bool RegrAdd::try_set_cheaper_cost(MarcroNodeId from_id, MircroNodeId to_id) {
    auto *source = get(from_id);
    auto *target = get(to_id);

    auto &source_common = source->common;
    auto &target_common = target->common;

    micro_node_explore_registry.mark_potential_later_explore(from_id, to_id,
                                                             source_common.current_cheapest_explore_cost);

    const bool no_improvement =
        source_common.current_cheapest_explore_cost >= target_common.current_cheapest_explore_cost;
    if (no_improvement) {
        return false;
    }

    target_common.current_cheapest_explore_cost = source_common.current_cheapest_explore_cost;
    target->current_cheapest_reached = from_id;
    target_common.current_explore_cost_bound = source_common.current_explore_cost_bound;
    return true;
}

// Marks `node` for exploration in the sub-query node manager, evaluates `collection` on
// the JOG with the sub-query cache, and reports whether the cached table of `node` is
// non-empty afterwards. The sub-query manager and cache are not reset here.
bool RegrAdd::is_supset(QueryEval::InitCollection &collection, NodeId node) {
    jog.mark_for_exploration(node, subquery_node_manager);
    jog.evaluate(collection, subquery_result_cache, subquery_node_manager, dbInfo);

    //TODO: wrap in function
    auto &&node_table = subquery_result_cache.get_node_table(jog.arr_get(node));
    const bool has_entries = !node_table.empty();
    return has_entries;
}

// Returns the micro node from which the macro parent of `node` was reached, provided the
// parent has exactly one `reached_from` entry; returns nullptr otherwise.
// `node` must not be null.
MicroNode *RegrAdd::next_micro_parent(MicroNode *node) {
    assert(node);

    auto *macro_parent = get(node->parent);
    const bool has_unique_origin = macro_parent->reached_from.size() == 1;

    return has_unique_origin ? get(macro_parent->reached_from.begin()->first) : nullptr;
}

// TODO: important should think about if it makes sense to also analyze this for non unique parents
// Reports whether `collection` can be pruned: true iff is_supset() holds for the query
// associated with micro node `id` or with one of the micro nodes obtained by repeatedly
// following next_micro_parent(). The sub-query node manager and result cache are
// brought up to the current JOG size first and are reset before returning.
bool RegrAdd::can_prune(QueryEval::InitCollection &collection, MircroNodeId id) {
    // TODO: hacky
    jog.adjust_for_new_nodes(subquery_node_manager, sq_node_amount_on_last_extension);
    jog.adjust_for_new_nodes(subquery_result_cache, sq_node_amount_on_last_extension);
    sq_node_amount_on_last_extension = jog.get_node_arr_am();

    bool prunable = false;
    for (auto *candidate = get(id); candidate != nullptr; candidate = next_micro_parent(candidate)) {
#ifdef RADD_EXPLORATION_SUPSET_CHECK_DEBUG_PRINT
        // NOTE(review): this always dumps the query of `id`, not of the ancestor that is
        // actually checked in this round — confirm this is intended.
        std::cout << "Performing subset check for:" << std::endl << "[explored]   ";
        //TODO: wrap below in function
        InitCollection atoms_underlying;
        std::vector<ll> ignore;
        jog.collect_atoms(atoms_underlying, QueryEval::NO_NODE, ignore, &get(id)->assoiciated_query);
        task.dump_dl_repr(atoms_underlying, std::cout);
        std::cout << std::endl << "[curent sup] ";
        task.dump_dl_repr(collection, std::cout);
        std::cout << std::endl;
#endif
        const bool covered = is_supset(collection, candidate->assoiciated_query);
#ifdef RADD_EXPLORATION_SUPSET_CHECK_DEBUG_PRINT
        if (covered) {
            std::cout << "Is subset." << std::endl << std::endl;
        } else {
            std::cout << "Is not subset." << std::endl << std::endl;
        }
#endif
        if (covered) {
            prunable = true;
            break;
        }
    }

    // Single clean-up point for both outcomes.
    subquery_node_manager.reset();
    subquery_result_cache.reset();
    return prunable;
}

// Re-runs the forwarding step of the macro backwards pass for `macro`, but only if the
// node was already explored and none of its grouped micro nodes is still unfulfilled.
// No new cost is recorded: the pass is started with NO_MICRO_NODE as origin.
void RegrAdd::potential_backwards_pass(HELP::MarcroNodeId macro) {
    auto *macro_node = get(macro);

    if (!was_explored(macro_node->common)) {
        return;
    }
    if (macro_node->unfulfilled_count) {
        return;
    }

    backwards_pass(macro, -1, NO_MICRO_NODE);
}

// Re-runs the micro backwards pass for `micro` with its already recorded fulfilled cost,
// but only if the node was explored and a cost has been recorded (i.e. the fulfilled
// cost is not the largest representable value).
void RegrAdd::potential_backwards_pass(HELP::MircroNodeId micro) {
    auto *micro_node = get(micro);

    if (!was_explored(micro_node->common)) {
        return;
    }

    const auto recorded_cost = micro_node->common.fulfilled_cost;
    if (recorded_cost == std::numeric_limits<ll>::max()) {
        return;
    }

    backwards_pass(micro, recorded_cost);
}

}
