#ifndef HELP_CORE_REGR_ADD_H
#define HELP_CORE_REGR_ADD_H

#include "heuristic.h"
#include "../conj_query_eval/join_realizer/join_order_graph.h"
#include "../conj_query_eval/join_realizer/start_node_manager.h"
#include "../utils/primitive_type_wrapper.hpp"
#include "../utils/common_func.h"

#include <algorithm>
#include <cassert>
#include <climits>
#include <compare>
#include <functional>
#include <limits>
#include <map>
#include <queue>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

/*
 * TODO (very important): if goals are split up, adjust the cost of micro nodes according to the
 *  other parts; in particular, if one part becomes a dead end, stop exploring.
 *
 * TODO
 *  Also cancel exploration when it (presumably the goal) evaluates to true.
 */

// debug flags
// #define RADD_EXPLORATION_BOUND_CHANGE // print explore cost bound changes

namespace HELP {

//TODO: the structs below could be part of class RegrAdd
/// Strongly typed id of a macro node; its value is the index into RegrAdd::macro_nodes.
/// TODO: fix typo in the type name ("Marcro" -> "Macro").
class MarcroNodeId : public primitive_type_wrapper<ll> {
    using Base = primitive_type_wrapper<ll>;

public:
    // Re-expose all constructors of the wrapped primitive (access of inherited
    // constructors follows the base class, not this using-declaration).
    using Base::Base;
};

/// Strongly typed id of a micro node; its value is the index into RegrAdd::micro_nodes.
/// TODO: fix typo in the type name ("Mircro" -> "Micro").
class MircroNodeId : public primitive_type_wrapper<ll> {
    using Base = primitive_type_wrapper<ll>;

public:
    // Re-expose all constructors of the wrapped primitive (access of inherited
    // constructors follows the base class, not this using-declaration).
    using Base::Base;
};

// Sentinel ids meaning "no micro node" / "no macro node" (value -1).
inline MircroNodeId NO_MICRO_NODE = MircroNodeId(-1);
inline MarcroNodeId NO_MACRO_NODE = MarcroNodeId(-1);

/**
 * Maps a node-id type to its "no node" sentinel, so templates over the id type can
 * test for a missing node (e.g. `id != RegrAddNoNode<T>::NO_NODE`).
 * Only the two specializations below are defined.
 */
template<class NodeIdT>
struct RegrAddNoNode;

// Sentinel for micro node ids.
template<>
struct RegrAddNoNode<MircroNodeId> {
    inline static MircroNodeId NO_NODE = NO_MICRO_NODE;
};

// Sentinel for macro node ids.
template<>
struct RegrAddNoNode<MarcroNodeId> {
    inline static MarcroNodeId NO_NODE = NO_MACRO_NODE;
};

/**
 * Id that refers to either a macro node or a micro node, so that both kinds can share
 * one priority queue (RegrAdd::q).
 *
 * Ordering (operator<=>): every macro id compares less than every micro id; ids of the
 * same kind compare by their numeric value.
 */
class CommonNodeId {
    // Unsigned on purpose: a signed 1-bit bit-field (this used to be `ll`) can only hold
    // 0 and -1, so storing `true` silently became -1 (and triggers clang's bit-field
    // truncation warning).
    unsigned long long _is_macro: 1;
    ll node_value: sizeof(ll) * CHAR_BIT - 1; //TODO: use llminus instead
public:
    CommonNodeId(bool is_macro, ll node_value) : _is_macro(is_macro), node_value(node_value) {}

    CommonNodeId(MarcroNodeId id) : _is_macro(true), node_value(id.get()) {}

    CommonNodeId(MircroNodeId id) : _is_macro(false), node_value(id.get()) {}

    CommonNodeId(const CommonNodeId &) = default;

    // NOTE: leaves both fields indeterminate unless value-initialized.
    CommonNodeId() = default;

    CommonNodeId &operator=(const CommonNodeId &other) = default;

    auto operator<=>(const CommonNodeId &other) const {
        if (_is_macro != other._is_macro) {
            // NOTE(review): the intent noted here originally was "micro smaller to explore
            // directly after macro exploration" (with a `/*TODO: !*/` marker), but this returns
            // `less` for macro ids, i.e. macro ids are the smaller ones. Behaviour kept; confirm.
            return /*TODO: !*/_is_macro ? std::strong_ordering::less : std::strong_ordering::greater;
        }
        // Same kind: order by the numeric id.
        return static_cast<ll>(node_value) <=> static_cast<ll>(other.node_value);
    }

    bool is_macro() const {
        return _is_macro != 0;
    }

    // Reinterpret the stored value as a micro node id (caller must check is_macro()).
    MircroNodeId get_node_as_micro() const {
        return node_value;
    }

    // Reinterpret the stored value as a macro node id (caller must check is_macro()).
    MarcroNodeId get_node_as_macro() const {
        return node_value;
    }
};

}

namespace std { //TODO: get rid of this
    // Hash a micro node id by its underlying integral value.
    template<>
    struct hash<HELP::MircroNodeId> {
        size_t operator()(const HELP::MircroNodeId &id) const noexcept {
            const auto raw_value = id.get();
            return static_cast<size_t>(raw_value);
        }
    };
}

namespace std { //TODO: get rid of this
    // Hash a macro node id by its underlying integral value.
    template<>
    struct hash<HELP::MarcroNodeId> {
        size_t operator()(const HELP::MarcroNodeId &id) const noexcept {
            const auto raw_value = id.get();
            return static_cast<size_t>(raw_value);
        }
    };
}

namespace HELP {

/// Exploration bookkeeping shared by MacroNode and MicroNode (both embed it as `common`).
struct CommonNodeContent {
    // constexpr static data members are implicitly inline (C++17), so ODR-using this
    // constant (e.g. binding it to a `const ll&` in std::min/std::max) needs no
    // out-of-class definition, unlike the previous `static const`.
    // NOTE(review): its use is not visible in this header; presumably a sentinel for
    // "not explored yet" -- confirm against the .cpp.
    static constexpr ll NOT_EXPLORED = -1;
    // Cost at which the node is fulfilled. The RegrAdd helpers compare it against
    // std::numeric_limits<ll>::max(), presumably meaning "not fulfilled yet".
    ll fulfilled_cost;
    // When set, RegrAdd::can_skip skips the node.
    bool do_not_explore;
    // Exposed through RegrAdd::was_explored.
    bool _was_explored;

    // Minimum over successors of (successor's value + edge cost), maintained by
    // RegrAdd::mark_minimal_achiever_cost; can_skip re-queues the node at this cost
    // when it exceeds the pop cost.
    ll already_further_explored_cost;
    // Cost at which the node is currently reached via its best supporter.
    ll current_cheapest_explore_cost;
    // Popping the node at or above this cost detaches it from its supporter (can_skip).
    ll current_explore_cost_bound;
};

/**
 * Macro node: groups a set of micro nodes. Identified by its index in
 * RegrAdd::macro_nodes (see MarcroNodeId).
 */
struct MacroNode {
    // Number of grouped micro nodes that are not fulfilled yet; RegrAdd::get_cost(MarcroNodeId)
    // reports std::numeric_limits<ll>::max() while this is non-zero.
    size_t unfulfilled_count;
    // Exploration bookkeeping shared with MicroNode.
    CommonNodeContent common;
    // Micro nodes represented by this macro node.
    // NOTE(review): the mapped ll is not interpreted in this header; presumably an index
    // into fulfilled_cost -- confirm.
    std::unordered_map<MircroNodeId, ll> groups_together;
    // Cost per groups_together micro node, -1 if not fulfilled yet.
    std::vector<ll> fulfilled_cost;

    // Collection of (micro node, action) pairs that reach this macro node.
    std::vector<std::pair<MircroNodeId, ll>> reached_from;
    // Current best supporter (a micro node).
    MircroNodeId current_cheapest_reached;

    // Micro nodes reaching this node in the current iteration -> action id; the action's cost is
    // looked up from it in RegrAdd::get_cost(MarcroNodeId, MircroNodeId).
    std::unordered_map<MircroNodeId, ll> reached_this_iteration;
    std::unordered_map<MircroNodeId, ll> reached_this_iteration_todo;
};

/**
 * Micro node: wraps one query of the join order graph. Identified by its index in
 * RegrAdd::micro_nodes (see MircroNodeId).
 * TODO: reduce object duplication between nodes via other_node_type
 */
struct MicroNode {
    // Exploration bookkeeping shared with MacroNode.
    CommonNodeContent common;
    // Macro node this micro node belongs to.
    MarcroNodeId parent;
    // Id of the wrapped query in the join order graph.
    QueryEval::NodeId assoiciated_query;
    QueryEval::InitCollection query_representation;
    // query_representation entry -> possible macro node creations.
    std::unordered_map<ll, std::vector<std::vector<MarcroNodeId>>> substitutes_to;
    // Marks whether the corresponding substitutes_to entry was computed.
    std::unordered_map<ll, std::vector<bool>> substitution_computed;
    // Links atoms of query_representation to join order graph nodes.
    std::unordered_map<QueryEval::NodeId, std::vector<std::pair<ll, ll>>> jog_query_repr_mapping;
    // Created macro node -> cheapest action id (used by RegrAdd::get_to(MircroNodeId)).
    std::unordered_map<MarcroNodeId, ll> action_annoation;

    // Successor macro nodes for the current iteration, plus a flag telling whether the list
    // was built (RegrAdd::get_to asserts it).
    bool substitued_to_this_iteration_created;
    std::vector<MarcroNodeId> substitued_to_this_iteration;

    // Macro nodes reaching this node, and the current best supporter among them.
    std::vector<MarcroNodeId> reached_from;
    MarcroNodeId current_cheapest_reached;

    // Macro nodes reaching this node in the current iteration.
    std::unordered_set<MarcroNodeId> reached_this_iteration;
    std::unordered_set<MarcroNodeId> reached_this_iteration_todo;

    // Maximal variable in query_representation.
    ll max_var;
};

/// Maps a node-id type to the id type of the opposite node kind:
/// macro <-> micro (exposed as the member alias OTHER_T).
template<typename NodeIdT>
struct other_node_type;

template<>
struct other_node_type<MarcroNodeId> {
    using OTHER_T = MircroNodeId;
};

template<>
struct other_node_type<MircroNodeId> {
    using OTHER_T = MarcroNodeId;
};

template<typename T>
class ExploreRegistry {
    class SingleManager {
        std::unordered_map<typename other_node_type<T>::OTHER_T, ll> cost_map; // maps parent to cheapest explore cost for parent
        std::priority_queue<std::pair<ll, typename other_node_type<T>::OTHER_T>> prio_q;

    public:
        void notify_removed(typename other_node_type<T>::OTHER_T from) {
            assert(cost_map.contains(from));
            cost_map[from] = -1;
        }

        void try_register(typename other_node_type<T>::OTHER_T from, ll cost) {
            if (!cost_map.contains(from) || cost_map.at(from) > cost) {
                cost_map.emplace(from, cost);
                prio_q.push({cost, from});
            }
        }

        std::pair<typename other_node_type<T>::OTHER_T, ll> get_alternative() {
            while (!prio_q.empty()) {
                auto [cost, node] = prio_q.top();
                prio_q.pop();
                auto map_cost = cost_map.at(node);
                if (map_cost == cost) {
                    assert(map_cost != -1);
                    return {node, cost};
                }
            }

            return {RegrAddNoNode<typename other_node_type<T>::OTHER_T>::NO_NODE, -1};
        }

        void reset() {
            cost_map.clear();
            prio_q = std::priority_queue<std::pair<ll, typename other_node_type<T>::OTHER_T>>();
        }
    };

    std::vector<SingleManager> managers;

public:
    void notify_new_node() {
        managers.emplace_back();
    }

    void mark_potential_later_explore(typename other_node_type<T>::OTHER_T from , T to, ll cost) {
        managers.at(to.get()).try_register(from, cost);
    }

    auto get_alternative(T node) {
        return managers.at(node.get()).get_alternative();
    }

    void reset_q_manager(T node) {
        managers.at(node.get()).reset();
    }

    void notify_removed(T node, typename other_node_type<T>::OTHER_T from) {
        managers.at(node.get()).notify_removed(from);
    }
};

/**
 * Regression heuristic built on the conjunctive-query evaluation machinery (join order
 * graph, table caches, start-node managers).
 *
 * It keeps a graph of macro nodes and micro nodes:
 * - a macro node groups micro nodes (see MacroNode);
 * - a micro node wraps a join-order-graph query (see MicroNode).
 * The graph is explored in cost order through the priority queue `q`.
 *
 * Every node remembers its current best supporter (`current_cheapest_reached`). Other
 * supporters are parked in the two ExploreRegistry instances, so that a node can be
 * re-attached after it is detached from its supporter (see detach_best_supporter /
 * detach_and_potentially_trigger_readd).
 *
 * NOTE(review): most of the algorithm lives in the .cpp; the name suggests a regression
 * variant of the additive heuristic -- confirm there.
 */
class RegrAdd : public Heuristic {
    size_t node_amount_on_last_extension = 0;
    size_t sq_node_amount_on_last_extension = 0;

    std::unordered_map<QueryEval::NodeId, ll> node_to_request;
    std::vector<QueryEval::InitCollection> action_to_condition; //TODO: initcollection should probably be called "normalized query representation" or something similar
    std::vector<std::vector<ll>> predicate_to_action; // maps predicate p to actions with add p(...)

    // Exploration queue of (cost, node), cheapest first (std::greater). At equal cost the
    // CommonNodeId ordering decides, which places macro ids before micro ids.
    std::priority_queue<std::pair<ll, CommonNodeId>, std::vector<std::pair<ll, CommonNodeId>>, std::greater<std::pair<ll, CommonNodeId>>> q; //TODO (important): FD uses its own implementation, should we do this here too?
    std::vector<CommonNodeId> seen_nodes;

    LiftedStripsTask &task;
    DBInfo dbInfo; //TODO rename
    QueryEval::JoinOrderGraph jog;
    QueryEval::TableCache result_cache;
    QueryEval::RegressiveStartNodeMangerStandard start_node_manager;

    QueryEval::TableCache subquery_result_cache;
    QueryEval::RegressiveStartNodeMangerNoExtraStaticHandle subquery_node_manager;

    // Node storage: MarcroNodeId / MircroNodeId values are indices into these vectors (see get()).
    std::vector<MacroNode> macro_nodes;
    std::vector<MicroNode> micro_nodes;
    std::vector<MircroNodeId> jog_to_micro;
    std::unordered_map<std::vector<QueryEval::NodeId>, MarcroNodeId> jog_to_macro; //TODO: clone boost hashes from FD/powerlifted?

    // Alternative supporters: macro nodes are supported by micro nodes and vice versa.
    ExploreRegistry<MarcroNodeId> macro_node_explore_registry;
    ExploreRegistry<MircroNodeId> micro_node_explore_registry;

    // Tag-dispatch helpers: select the registry matching the node-id type.
    ExploreRegistry<MarcroNodeId> &get_registry_by_t(MarcroNodeId) {
        return macro_node_explore_registry;
    }

    ExploreRegistry<MircroNodeId> &get_registry_by_t(MircroNodeId) {
        return micro_node_explore_registry;
    }

    MarcroNodeId goal;

    QueryEval::NodeId retrieve_and_register_last_node();

    MarcroNodeId create_goal();
    void extend_condition(QueryEval::InitCollection &original, QueryEval::InitCollection &extension, QueryEval::InitCollection &result, ll p_to_replace, ll p_args_pos_to_replace, std::map<ll, ll> &extension_var_map);
    void create_var_map(std::vector<ll> &new_add_args, QueryEval::InitCollection &achiever, ll base_max, std::map<ll,ll> &result);
    ll combine_cost(ll cost, ll action);
    bool fulfilled(QueryEval::NodeId node, StripsState &state);
    void node_state_eval_print(QueryEval::NodeId node, StripsState &state);
    void add_successor_to_q(MircroNodeId node_id, ll predicate, ll i);
    bool try_set_cheaper_cost(MircroNodeId micro_id, ll act, MarcroNodeId macro_id);
    bool try_set_cheaper_cost(MarcroNodeId macro_id, MircroNodeId micro_id);

    // Node access by id. The returned pointer refers into micro_nodes / macro_nodes and is
    // invalidated whenever that vector reallocates.
    MicroNode *get(MircroNodeId id) {
        assert(id != NO_MICRO_NODE);
        assert(micro_nodes.size() > id.get());
        return &micro_nodes.at(id.get());
    }
    MacroNode *get(MarcroNodeId id) {
        assert(id != NO_MACRO_NODE); // same sentinel check as the micro overload
        assert(macro_nodes.size() > id.get());
        return &macro_nodes.at(id.get());
    }

    void reset(CommonNodeContent& content);
    void reset(MicroNode &node);
    void reset(MacroNode &node);

    void explore(MircroNodeId node_id, StripsState &state);
    void explore(MarcroNodeId node_id);

    // Whether `id` indexes an existing entry of macro_nodes / micro_nodes.
    bool valid_node(CommonNodeId id) {
        if (id.is_macro()) {
            return (id.get_node_as_macro().get() < macro_nodes.size());
        } else {
            return (id.get_node_as_micro().get() < micro_nodes.size());
        }
    }

    MircroNodeId construct(QueryEval::NodeId node_id, MarcroNodeId parent);

    MarcroNodeId construct(QueryEval::InitCollection &query);
    MarcroNodeId construct(std::vector<QueryEval::NodeId> &node_collection);

    void backwards_pass(MircroNodeId node_id, ll cost);
    void backwards_pass(MarcroNodeId node_id, ll cost, MircroNodeId from);

    void split_init_by_vars(std::vector<QueryEval::InitCollection> &split_up_result, QueryEval::InitCollection &init_collection);

    void get_next_nodes(MircroNodeId node_id);
    std::vector<MarcroNodeId> &generate_successor(MircroNodeId node_id, ll predicate, ll i);

    // Id of a node stored in macro_nodes (its index, via get_vec_index).
    MarcroNodeId get_macro_id(MacroNode &node) {
        return MarcroNodeId{get_vec_index(macro_nodes, node)};
    }

    // Id of a node stored in micro_nodes (its index, via get_vec_index).
    MircroNodeId get_micro_id(MicroNode &node) {
        return MircroNodeId{get_vec_index(micro_nodes, node)};
    }

    // Cost of a macro node: max() while any grouped micro node is unfulfilled, otherwise its fulfilled cost.
    ll get_cost(MarcroNodeId node_id) {
        auto *node = get(node_id);
        return node->unfulfilled_count ? std::numeric_limits<ll>::max() : node->common.fulfilled_cost;
    }

    MicroNode *next_micro_parent(MicroNode *node);
    bool can_prune(QueryEval::InitCollection &collection, MircroNodeId id);
    bool is_supset(QueryEval::InitCollection &collection, QueryEval::NodeId node); // by checking \Phi(node.initcol) \subseteq macro

    void register_q_exploral(MircroNodeId, MarcroNodeId, ll act_id);
    void register_q_exploral(MarcroNodeId, MircroNodeId);

    void potential_backwards_pass(MarcroNodeId);
    void potential_backwards_pass(MircroNodeId);

    // Forgets `from` as a predecessor of `node` for this iteration (both reached_this_iteration
    // and reached_this_iteration_todo).
    template<typename T>
    void remove_old_predecessor(T node, typename other_node_type<T>::OTHER_T from) {
        auto *_node = get(node);
        assert(_node->reached_this_iteration.contains(from));
        _node->reached_this_iteration.erase(from);
        _node->reached_this_iteration_todo.erase(from);
    }

    // Cost of reaching micro node `node` from macro `parent`: the parent's cost (edge cost 0).
    ll get_cost(MircroNodeId node, MarcroNodeId parent) {
        return get(parent)->common.current_cheapest_explore_cost;
    }

    // Cost of reaching macro node `node` from micro `parent`: the parent's cost plus the cost
    // of the action recorded in node->reached_this_iteration[parent].
    ll get_cost(MarcroNodeId node, MircroNodeId parent) {
        assert(get(node)->reached_this_iteration.contains(parent));
        return get(parent)->common.current_cheapest_explore_cost + task.get_actions().at(get(node)->reached_this_iteration.at(parent)).get_cost();
    }

    /**
     * Detaches `node` from its current best supporter (current_cheapest_reached).
     *
     * allow_removal: if true, the old supporter is also erased from the node's per-iteration
     *   predecessor bookkeeping; otherwise it is handed back to the registry.
     * Returns false if the registry provided an alternative supporter. The node is then
     *   re-queued at that cost and inherits the alternative's explore cost bound.
     * Returns true if no alternative was left and the node was reset (only allowed together
     *   with allow_removal).
     */
    template<typename T>
    bool detach_best_supporter(T node, bool allow_removal) {
        auto old_from = get(node)->current_cheapest_reached;
        get_registry_by_t(node).notify_removed(node, old_from); //TODO move
        CommonNodeContent &common = get_common(node);

        if (allow_removal) {
            // do remove
            remove_old_predecessor(node, old_from);
        } else  {
            // NOTE(review): notify_removed() above stored -1 for old_from, and the registry only
            // re-registers a parent whose stored cost is greater than the new one. For
            // non-negative costs this call therefore appears to be a no-op -- confirm intent.
            get_registry_by_t(node).mark_potential_later_explore(old_from, node, get_cost(node, old_from));
        }

        auto [alt_from, alt_cost] = get_registry_by_t(node).get_alternative(node);
        assert(alt_from != get(node)->current_cheapest_reached);
        assert(common.fulfilled_cost == std::numeric_limits<ll>::max());

        if (alt_from != RegrAddNoNode<T>::NO_NODE) {
            q.push({alt_cost, node});

            get(node)->current_cheapest_reached = alt_from;
            common.current_explore_cost_bound = get(alt_from)->common.current_explore_cost_bound;
            common.current_cheapest_explore_cost = alt_cost;
            return false;
        } else {
            assert(allow_removal);
            reset(*get(node));
            return true;
        }
    }

    /**
     * Recursively processes the chain of best supporters of `node`, then detaches `node`.
     *
     * - Returns false without changes if `node`'s explore cost bound differs from its
     *   supporter's.
     * - Returns true if `node` is fulfilled (fulfilled_cost != max()).
     * - Otherwise the supporter is processed first. Its result is passed as allow_removal
     *   to detach_best_supporter(node, ...), whose result is returned.
     *
     * NOTE(review): the first check calls get_common() on current_cheapest_reached. If that
     * is the NO_NODE sentinel, get() fails (assert / std::out_of_range from vector::at).
     * Presumably callers never reach such a node -- confirm.
     */
    template<typename T>
    bool detach_and_potentially_trigger_readd(T node) {
        auto *_node = get(node);
        CommonNodeContent &common = get_common(node);

        if (common.current_explore_cost_bound != get_common(_node->current_cheapest_reached).current_explore_cost_bound) {
            return false; //TODO: does this make sense?
        }

        if (common.fulfilled_cost != std::numeric_limits<ll>::max()) {
            return true;
        }

        bool allow_removal = detach_and_potentially_trigger_readd(_node->current_cheapest_reached);
        bool no_parent = detach_best_supporter(node, allow_removal);

        return no_parent;
    }

    /**
     * Whether a node taken at `pop_cost` (presumably popped from `q`) should not be explored now.
     * Two of the skip reasons have side effects:
     * - the node is re-queued at already_further_explored_cost if that exceeds pop_cost;
     * - the node is detached from its supporter if pop_cost reached its explore cost bound.
     */
    template<typename T>
    bool can_skip(T node, ll pop_cost) {
        if (get(node)->common.do_not_explore) {
            return true;
        }

        if (was_explored(node)) {
            return true;
        }

        if (get(node)->common.already_further_explored_cost > pop_cost) {
            q.push({get(node)->common.already_further_explored_cost, node});
            return true;
        }

        if (pop_cost >= get(node)->common.current_explore_cost_bound) {
            detach_and_potentially_trigger_readd(node);
            return true;
        }

        return false;
    }

    CommonNodeContent &get_common(MircroNodeId id) {
        return get(id)->common;
    }

    CommonNodeContent &get_common(MarcroNodeId id) {
        return get(id)->common;
    }

    CommonNodeContent &get_common(CommonNodeId id) {
        if (id.is_macro()) {
            return get_common(id.get_node_as_macro());
        } else {
            return get_common(id.get_node_as_micro());
        }
    }

    // Successor macro nodes of a micro node for this iteration, each paired with the cost of
    // its annotated action.
    std::vector<std::pair<MarcroNodeId, ll>> get_to(MircroNodeId node) { // TODO: should not filter and crate this every time here
        std::vector<std::pair<MarcroNodeId, ll>> res;
        assert(get(node)->substitued_to_this_iteration_created);

        for (auto macro : get(node)->substitued_to_this_iteration) {
            res.emplace_back(macro, task.get_actions()[get(node)->action_annoation.at(macro)].get_cost());
        }

        return res;
    }

    bool was_explored(CommonNodeContent &cont) {
        return cont._was_explored;
    }

    template<typename N>
    bool was_explored(N node) { //TODO: generalize this patter
        return was_explored(get_common(node));
    }

    // Micro nodes grouped by a macro node, each with edge cost 0.
    std::vector<std::pair<MircroNodeId, ll>> get_to(MarcroNodeId node) {
        std::vector<std::pair<MircroNodeId, ll>> res;
        for (auto &[to, _] : get(node)->groups_together) {
            res.emplace_back(to, 0);
        }
        return res;
    }

    /**
     * Lowers `node`'s explore cost bound to `cost` (no-op if it is not lower).
     *
     * If the node was already explored, the new bound is pushed down to every successor
     * whose best supporter is `node`. already_further_explored_cost is then recomputed as
     * the minimum over successors of (successor's value + edge cost).
     */
    template<typename T>
    void mark_minimal_achiever_cost(T node, ll cost) {
        CommonNodeContent &common = get_common(node);

        if (common.current_explore_cost_bound > cost) {
            common.current_explore_cost_bound = cost;
#ifdef RADD_EXPLORATION_BOUND_CHANGE
            std::cout << "Explore cost bound for " << (std::is_same<T, MircroNodeId>::value ? "micro" : "macro") << " node " << node.get() << " is: " << get(node)->common.current_explore_cost_bound << " (updated by forward pass)" << std::endl;
#endif

            if (!was_explored(node)) {
                return;
            }

            const ll cost_limit = std::numeric_limits<ll>::max();
            auto to_ids = get_to(node);
            ll further_explored = cost_limit;
            for (auto &[new_to, add_cost] : to_ids) {
                auto *to_node = get(new_to);
                if (to_node->current_cheapest_reached == node) {
                    mark_minimal_achiever_cost(new_to, cost);
                }
                // Saturate instead of overflowing: a successor can hold max() here (e.g. one
                // whose own successor list was empty), and max() + add_cost is signed overflow.
                const ll to_cost = get_common(new_to).already_further_explored_cost;
                const ll candidate = (add_cost > 0 && to_cost > cost_limit - add_cost) ? cost_limit : to_cost + add_cost;
                further_explored = std::min(further_explored, candidate);
            }
            common.already_further_explored_cost = further_explored;
        }
    }

public:
    RegrAdd(LiftedStripsTask &task);
    virtual ~RegrAdd() = default;
    ll compute(StripsState &state) override;
};

}


#endif //HELP_CORE_REGR_ADD_H
