#ifndef HELP_CORE_START_NODE_MANAGER_H
#define HELP_CORE_START_NODE_MANAGER_H

#include "../../utils/type_defs.h"
#include "../../task_formulations/db_propositional_shared.h"
#include "../join_realizer/init_nodes.h"
#include "../table_representation/table_cache.h"

#include <queue>

namespace HELP { 

namespace QueryEval {
    struct StartNodeContainer {
        std::vector<ll> predicates_used_collection; // TODO: is memorizing in this way actually smart?, we will not visit in memory order
        std::vector<std::vector<NodeId>> start_nodes_per_pred;

        StartNodeContainer(size_t pred_am) : start_nodes_per_pred(pred_am) {}

        StartNodeContainer(StartNodeContainer &other) : start_nodes_per_pred(other.start_nodes_per_pred.size()) {
            for (auto p : other.predicates_used_collection) {
                start_nodes_per_pred[p] = other.start_nodes_per_pred[p];
            }
            predicates_used_collection = other.predicates_used_collection;
        }

        StartNodeContainer& operator=(StartNodeContainer& other) {
            for (auto p : predicates_used_collection) {
                start_nodes_per_pred.at(p).clear();
            }

            start_nodes_per_pred.resize(other.start_nodes_per_pred.size());
            predicates_used_collection = other.predicates_used_collection;

            for (auto p : predicates_used_collection) {
                other.start_nodes_per_pred[p] = start_nodes_per_pred[p];
            }

            return other;
        }

        void register_node(ll pred, NodeId node) {
            if (start_nodes_per_pred.at(pred).empty()) {
                predicates_used_collection.push_back(pred);
            }

            start_nodes_per_pred.at(pred).push_back(node);
        }

        bool empty() {
            return predicates_used_collection.empty();
        }

        void clear() {
            for (auto p : predicates_used_collection) {
                start_nodes_per_pred[p].clear();
            }

            predicates_used_collection.clear();
        }

        void collect(GroundAtomCol &state, std::queue<NodeId> &q) { //TODO (important) let iteration depend on if state or startnodes is smaller
            for (auto p : state.get_non_empty_preds()) {
                for (auto &node : start_nodes_per_pred[p]) {
                    q.push(node);
                }
            }
            clear(); //TODO important: is this really generally expected behavior?
        }

        void collect(InitCollection &init_col, std::queue<NodeId> &q) { //TODO: should just also use DBState here
            for (auto p : predicates_used_collection) {
                if (init_col.predicate_init.contains(p) && !init_col.predicate_init.at(p).empty()) {
                    for (auto &node : start_nodes_per_pred[p]) {
                        q.push(node);
                    }
                }
            }
            clear(); //TODO important: is this really generally expected behavior?
        }
    };

    /// Abstract strategy that decides where a join-graph traversal starts and which
    /// nodes/edges it may follow.
    ///
    /// Subclasses push their start nodes onto a caller-provided queue in collect_start_nodes
    /// and answer should_visit for ids (exact semantics are defined by each subclass).
    class StartNodeManager {
    protected:
        // Extra start nodes filled by subclasses (e.g. register_additional_start_node);
        // how they are consumed is up to whoever calls get_additional_start_nodes().
        std::vector<NodeId> additional_start_nodes;

    public:
        /// Mutable access to the extra start nodes.
        std::vector<NodeId> &get_additional_start_nodes() {
            return additional_start_nodes;
        }

        /// Pushes the start nodes relevant for a ground-atom collection onto q.
        virtual void collect_start_nodes(GroundAtomCol &info, std::queue<NodeId> &q) = 0;

        /// Pushes the start nodes relevant for an init collection onto q.
        virtual void collect_start_nodes(InitCollection &info, std::queue<NodeId> &q) = 0;

        /// Whether the element with the given id should be visited.
        virtual bool should_visit(ll id) = 0;

        /// Whether static predicates are taken into account; defaults to yes.
        virtual bool does_consider_static() {
            return true;
        }

        /// Edges to follow out of `node`; by default all of node.edges.
        virtual std::vector<QueryEval::NodeId> &get_relevant_node_edges(QueryEval::JoinGraphNode &node) {
            auto &all_edges = node.edges;
            return all_edges;
        }

        /// Hook for resetting the edge list returned by get_relevant_node_edges; no-op by default.
        virtual void reset_relevant_node_edges([[maybe_unused]] QueryEval::JoinGraphNode &node) {
            // The default hands out node.edges directly, so there is nothing to reset.
        }

        virtual ~StartNodeManager() = default;
    };

    class SimplePredInitializerWrapManager : public StartNodeManager { //TODO: seems very hacky
        std::unordered_map<ll, std::map<std::vector<ll>, NodeId>> *predicate_init_node_link; //TODO: use const & //TODO: wrap type heere and in join_order_graph.h -- this is just a link to predicate_init_node
    public:
        SimplePredInitializerWrapManager(std::unordered_map<ll, std::map<std::vector<ll>, NodeId>> *predicate_init_node_link) : predicate_init_node_link(predicate_init_node_link) {}
        virtual ~SimplePredInitializerWrapManager() = default;

        virtual void collect_start_nodes(GroundAtomCol &state, std::queue<NodeId> &q) override {
            for (auto p : state.get_non_empty_preds()) {
                if (predicate_init_node_link->contains(p)) {
                    for (auto &[_, node] : predicate_init_node_link->at(p)) {
                        q.push(node);
                    }
                }
            }
        }

        virtual void collect_start_nodes(InitCollection &info, std::queue<NodeId> &q) override { //TODO: combine with above
            for (auto &[p, _] : info.predicate_init) { // TODO: proper _?
                if (predicate_init_node_link->contains(p)) {
                    for (auto &[_, node] : predicate_init_node_link->at(p)) {
                        q.push(node);
                    }
                }
            }
        }

        bool should_visit(ll id) override {
            return true;
        }
    };

    /// StartNodeManager for regressive evaluation.
    ///
    /// collect_start_nodes drains `current_start_nodes` onto the queue (StartNodeContainer::collect
    /// clears it afterwards) and advances `collect_time_stamp`. reset() advances `time_stamp`.
    /// The per-table containers `used`, `just_used` and `current_edges` are AutoResetVecs bound to
    /// one of these counters (presumably they lazily reset entries when their counter changes — confirm).
    ///
    /// register_node, register_predecessor, extend, set_used and should_visit are defined in
    /// start_node_manager.tpp.
    ///
    /// @tparam extra_static_handling value reported by does_consider_static().
    /// NOTE: the class name is misspelled ("Manger"); kept for compatibility with existing callers.
    template<bool extra_static_handling=true> //TODO: probably should optimize more given this condition
    class RegressiveStartNodeManger : public StartNodeManager {
        // Per-start-node bookkeeping stored in reg_data (presumably filled by register_node).
        struct RegDataEntry {
            ll predicate;
            NodeId id;
        };

        // Declaration order matters: the AutoResetVec members below are constructed with these
        // counters, and members are always initialized in declaration order.
        unsigned long long time_stamp=0; //TODO: move to parent
        unsigned long long collect_time_stamp=0; // counts how many times collect_start_nodes was called
        StartNodeContainer current_start_nodes;
        AutoResetVec<char> used; //TODO: should be bool
        AutoResetVec<char> just_used; //TODO: should be bool
        AutoResetVec<std::vector<QueryEval::NodeId>> current_edges; // per-node edge lists (indexed by JoinGraphNode::arr_id)
        std::unordered_map<ll, RegDataEntry> reg_data;

        /// True iff `id` has an entry in reg_data.
        bool is_start_node(ll id) {
            return reg_data.contains(id);
        }

    public:
        /// @param table_am size of the per-table containers (used, just_used, current_edges).
        /// @param pred_am  number of predicate buckets in current_start_nodes.
        // Mem-initializers listed in declaration order (avoids -Wreorder; behavior unchanged).
        RegressiveStartNodeManger(ll table_am, ll pred_am)
            : current_start_nodes(pred_am),
              used(time_stamp, table_am),
              just_used(collect_time_stamp, table_am),
              current_edges(time_stamp, table_am) {}
        virtual ~RegressiveStartNodeManger() = default;

        /// Drains the buffered start nodes relevant for `info` onto q and advances collect_time_stamp.
        void collect_start_nodes(InitCollection &info, std::queue<NodeId> &q) override {
            current_start_nodes.collect(info, q);
            collect_time_stamp++;
        }

        /// Drains the buffered start nodes relevant for `info` onto q and advances collect_time_stamp.
        void collect_start_nodes(GroundAtomCol &info, std::queue<NodeId> &q) override {
            current_start_nodes.collect(info, q);
            collect_time_stamp++;
        }

        void register_node(NodeId id, JoinGraphNode &node, NodeLookup *node_lookup);
        void register_predecessor(NodeId to, NodeId from, NodeLookup *node_lookup);
        void extend(size_t new_size);
        bool set_used(ll id);
        bool should_visit(ll id) override;

        /// Advances time_stamp (the counter `used` and `current_edges` are bound to).
        void reset() {
            time_stamp++;
        }

        /// True iff `id` is marked in `used` but not in `just_used`.
        bool was_explore_marked(ll id) {
            return used.at(id) && !just_used.at(id);
        }

        /// Appends `id` to the inherited additional_start_nodes.
        void register_additional_start_node(NodeId id) {
            additional_start_nodes.push_back(id);
        }

        /// Returns the template parameter extra_static_handling.
        bool does_consider_static() override {
            return extra_static_handling;
        }

        /// Edge list currently stored for `node` (indexed by node.arr_id).
        std::vector<QueryEval::NodeId> &get_relevant_node_edges(QueryEval::JoinGraphNode &node) override {
            assert(node.arr_id < current_edges.size());
            return current_edges.at(node.arr_id);
        }

        /// Clears the edge list currently stored for `node`.
        void reset_relevant_node_edges(QueryEval::JoinGraphNode &node) override {
            assert(node.arr_id < current_edges.size());
            current_edges.at(node.arr_id).clear();
        }
    };

    using RegressiveStartNodeMangerStandard = RegressiveStartNodeManger<true>;
    using RegressiveStartNodeMangerNoExtraStaticHandle = RegressiveStartNodeManger<false>;
}

}

#include "start_node_manager.tpp"

#endif //HELP_CORE_START_NODE_MANAGER_H
