#ifndef HELP_CORE_JOIN_ORDER_GRAPH_H
#define HELP_CORE_JOIN_ORDER_GRAPH_H

#include "../join_and_project.h"
#include <map>
#include <unordered_map>
#include <boost/variant.hpp>
#include "nodes.h"
#include "init_nodes.h"
#include "../table_representation/table_cache.h"
#include "../join_order_generator/join_order_generator.h"
#include "start_node_manager.h"
#include "../../utils/hashes.h"

namespace HELP {

// TODO: match static initializer tables; this way const and equality initializers could be merged

namespace QueryEval {
// TODO: what is this? - relation to Helmert grounding, lifted successor generation and modern query optimization techniques
// TODO: call it like binary and forward only

// TODO: feels wrong to be here
/**
 * Identifies one row of a predicate's initializer table: the predicate id plus the row's index
 * within that predicate's table (InitColLookup assigns these indices, counting from 0 per predicate).
 * Members are value-initialized so a default-constructed id never carries garbage into hashing/equality.
 */
struct InitColId {
    ll predicate{};
    ll row_num{};

    bool operator==(const InitColId& rhs) const = default;
};

/**
 * Lookup key pairing a predicate id with a concrete row of that predicate
 * (used by InitColLookup to map a row back to its InitColId).
 */
struct PredAndRow {
    ll pred{};
    JoinTable::Row row;

    bool operator==(const PredAndRow& rhs) const = default;
};
}
}

namespace std {
template <>
struct hash<HELP::QueryEval::InitColId> {
    /// Hashes an InitColId by mixing the hashes of both fields (boost::hash_combine style).
    /// The former `predicate << (sizeof(HELP::ll)-8)` shifted by 0 bits for an 8-byte ll, so the hash
    /// degenerated to `predicate | row_num` (e.g. (1,2), (2,1) and (3,0) all collided) and it also
    /// left-shifted a signed value. All mixing here is done on unsigned std::size_t.
    std::size_t operator()(const HELP::QueryEval::InitColId &n_id) const noexcept
    {
        std::size_t seed = std::hash<HELP::ll>{}(n_id.predicate);
        seed ^= std::hash<HELP::ll>{}(n_id.row_num) + std::size_t{0x9e3779b9} + (seed << 6) + (seed >> 2);
        return seed;
    }
};
}

namespace std {
template <>
struct hash<HELP::QueryEval::PredAndRow> {
    /// Hashes a (predicate, row) key by mixing the row hash with the predicate hash
    /// (boost::hash_combine style). The former `pred << (sizeof(HELP::ll)-8)` shifted by 0 bits for an
    /// 8-byte ll, so the result was just `hash(row) ^ pred`, which collides whenever two rows' hashes
    /// differ by exactly the XOR of their predicate ids.
    std::size_t operator()(const HELP::QueryEval::PredAndRow &prow) const noexcept
    {
        std::size_t seed = std::hash<HELP::QueryEval::JoinTable::Row>{}(prow.row);
        seed ^= std::hash<HELP::ll>{}(prow.pred) + std::size_t{0x9e3779b9} + (seed << 6) + (seed >> 2);
        return seed;
    }
};
}

namespace HELP {
namespace QueryEval {

    // TODO: feels wrong to be here
    /**
     * Maps every row of every predicate's initializer table to its InitColId,
     * i.e. (predicate id, index of the row within that predicate's table).
     */
    class InitColLookup {
        // (predicate, row) -> InitColId of that row
        std::unordered_map<PredAndRow, InitColId>
            mapping;  // TODO (important): should probably just have one map per pred

    public:
        /**
         * Indexes all rows of init_collection.predicate_init. Row indices restart at 0 for each
         * predicate and follow the iteration order of that predicate's table. If a table contains the
         * same row twice, the first occurrence keeps its index (emplace does not overwrite) while the
         * counter still advances.
         */
        InitColLookup(InitCollection &init_collection)
        {
            for (auto &[pred, table] : init_collection.predicate_init) {
                ll row_id = 0;
                for (auto &row : table) {
                    mapping.emplace(PredAndRow{pred, row}, InitColId{pred, row_id++});
                }
            }
        }

        /**
         * Returns the InitColId of `row` in the table of `predicate`.
         * The pair must have been indexed by the constructor: this is asserted in debug builds, and
         * unordered_map::at throws std::out_of_range otherwise.
         */
        InitColId look_up(ll predicate, const JoinTable::Row &row) const
        {
            // Build the key once; previously `row` was copied by value and then into two temporaries.
            const PredAndRow key{predicate, row};
            assert(mapping.contains(key));
            return mapping.at(key);
        }
    };

    //TODO: reorder functions by purpose
    /*
     * Graph of initializer, merge and reorder nodes built from (annotated) join orders.
     * Nodes live in one vector per kind (init_nodes, merge_nodes, reorder_nodes). A NodeId's id_num indexes
     * the vector of its kind, while each node's arr_id is the index used by the table caches and
     * start node managers created below (all of them are sized with get_node_arr_am()).
     */
    class JoinOrderGraph : public NodeLookup { //TODO: the nodelookup should never exist
        ll req_id_count = 0;
        ll arr_id_count = 0; //TODO: should probably subst. ll with appropriate wrapper
        ll last_request = -1; // -1 = no request; was previously left uninitialized //TODO: assert that get is not called with -1?
        std::vector<MergeNode> merge_nodes;
        std::vector<InitNode> init_nodes;
        std::vector<ReOrderNode> reorder_nodes;

        /* RIP hacky_temporary_cache for side computations - in case this is needed again: f843b933ab017ceb44b596229759c231cf0224c0*/
        DBInfo db_info_copy; //TODO: should this really be here?

        //TODO: typing seems wrong //TODO: use unordered map? //TODO important //TODO: important
        std::unordered_map<ll, std::map<std::vector<ll>, NodeId>> predicate_init_node; // (predicate id, par permutation) -> initializer node id
        std::unordered_map<ll, std::pair<std::unordered_map<ll, ll>, NodeId>> requests; // request id -> (var map, result node) //TODO: clean up types

        NodeId create_new_predicate_initializer(ll predicate_id, std::vector<ll> &pos_order, DBInfo &info);
        NodeId create_new_predicate_initializer(ll predicate_id, ll predicate_arity, DBInfo &info);
        NodeId get_predicate_initializer(ll predicate_id, ll predicate_arity, std::vector<ll> &reorder, DBInfo &info);

        NodeId create_join_node(NodeId left, NodeId right,
                                std::vector<PositionForward> &left_vars_to_forward,
                                std::vector<PositionForward> &right_vars_to_forward,
                                std::vector<PositionMerge> &join_rules,
                                std::unordered_map<ll, ll> *var_to_pos,
                                const std::unordered_map<ll, ll> *original_l_var_to_pos,
                                const std::unordered_map<ll, ll> *original_r_var_to_pos);

        bool to_has_from(NodeId from, NodeId to);

        void dump_propositional_repr(NodeId node); // debug only

        //TODO: naming below is a mess
        NodeId build_join_by_rule(NodeId from, NodeId to, std::unordered_map<ll, ll> &f_var_to_pos, std::unordered_map<ll, ll> &t_var_to_pos, const AnnotatedJoinOrderElement &el, std::unordered_map<ll, ll> &result_var_to_pos, bool try_match_only);
        void init_if_needed(Query &query, ll id, std::vector<NodeId> &current_node_map, std::vector<std::unordered_map<ll, ll>> &current_var_to_pos_map,  std::map<ll, NodeId> *predicate_node_map, DBInfo &info);
        NodeId build_initialize(Atom &atom, std::vector<ll> &reorder, DBInfo &info);

        void compute_joined_positions(std::vector<PositionMerge>  &vars_to_join, NodeId to, NodeId from);

        void mark_r_negated(NodeId id);

        void get_new_var_to_pos(std::unordered_map<ll, ll> &new_var_to_pos,  std::unordered_map<ll, ll> &var_to_pos, NodeId end_node, NodeId part_of_merge);

        NodeId construct_reorder(QueryEval::NodeId from, std::vector<PositionMerge> &joined,
                                 std::vector<PositionForward> &forwarded, bool is_left,
                                 ll forward_offset, std::unordered_map<ll, ll> *re_reorder);
        void reorder_joins(std::vector<PositionMerge> &join_rules,
                           std::vector<PositionForward> &left_vars_to_forward,
                           std::vector<PositionForward> &right_vars_to_forward,
                           std::unordered_map<ll, ll> &var_to_pos);

        /// Resolves a NodeId to the node it names by dispatching on the id's kind;
        /// id_num indexes the vector that belongs to that kind.
        JoinGraphNode &get(NodeId id) {
            if (id.is_init()) { //TODO visitor or class structure please
                assert(id.id_num < init_nodes.size());
                return init_nodes[id.id_num];
            } else if (id.is_merge()) {
                assert(id.id_num < merge_nodes.size());
                return merge_nodes[id.id_num];
            } else if (id.is_reorder()) {
                assert(id.id_num < reorder_nodes.size());
                return reorder_nodes[id.id_num];
            } else {
                assert(false && "shouldn't happen");
                // NOTE(review): with NDEBUG this falls through to an unchecked reorder_nodes access
                return reorder_nodes[id.id_num]; //TODO: hacky, mby return nullptr, chase for runtime err, still hacky but better
            }
        }

        void merge(MergeNode &node, TableCache &cache);
        void reorder(ReOrderNode &node, TableCache &cache);

        void verify_binary(NodeId end);
        void verify_reorder(NodeId match, NodeId original_reorder, bool is_left, std::vector<PositionMerge> &joined);

        void build_simple(Query &query, AnnotatedJoinOrder &join_order, DBInfo &info, std::map<ll, NodeId> *predicate_node_map=nullptr);
        void try_optimized_build(std::vector<const AnnotatedJoinOrderElement *> &delayed_elements, std::vector<NodeId> &current_node, std::vector<std::unordered_map<ll, ll>> &current_var_to_pos, const AnnotatedJoinOrderElement *first_el /*for debugging only*/);
        bool build_simple_process_je_element(std::vector<NodeId> &current_node, std::vector<std::unordered_map<ll, ll>> &current_var_to_pos, const AnnotatedJoinOrderElement &join, bool try_match_only);
        void create_and_mark_result(NodeId last_node, Parameters &result_pars, std::unordered_map<ll, ll> &old_var_to_pos);
        void mark_result(ll request_id, std::unordered_map<ll, ll> &var_to_pos, NodeId last_node);

        ll get_pos_size(NodeId n_id); // for debugging only

        /* TODO: NodeId build_reducer(NodeId node_id, std::unordered_set<NodeId> &extra_end_nodes); implement this, draft present @6e19b62939484ddf36d70a982acf1eabc3cd1667*/

        std::unordered_map<NodeId, std::unordered_map<std::vector<PositionForward>, NodeId>> reorder_matches; // (from_node, reordering) -> reorder_node_id //TODO: seems obsolete
        std::unordered_map<AdaptiveReorderHash, NodeId> adaptive_reorder_matches;
        std::unordered_map<MergeNodeIdentifierHashKey, NodeId> merge_matches;
        std::unordered_map<CreateJoinNodeCacheHash, NodeId> join_matches;
        std::unordered_map<MultipleJoinMatchHash, std::tuple<NodeId,std::unordered_map<ll, ll>,std::unordered_map<ll, ll>>> multiple_join_matches; //TODO: make struct {matched node, to_reorder (means reorder from original to positions to new to positions), first_from_reorder (means reorder from first from elements positions to to positions)}

        void try_match(NodeId node, std::vector<InitCollection> &result,
                       TableCache &cache, ll match_bound, InitColLookup &lookup);
        NodeId mark_head_positions_forward(NodeId node,
                                           ll node_table_id, // new table for current node
                                           TableCache &table_cache, // global table cache
                                           std::vector<JoinTable::Table> &hacky_table_cache, // local table cache
                                           bool &worked,
                                           JoinOrderGraph &flat_jog,
                                           std::unordered_map<NodeId, std::vector<InitColId>> &cache_entries_as_init_col_ids,
                                           std::vector<NodeId> &variable_renaming,
                                           InitColLookup &lookup);
        NodeId extract_reduced_flat_jog_from_this(JoinOrderGraph &flat_jog,
                                                  NodeId node,
                                                  std::unordered_map<NodeId, std::vector<InitColId>> &cache_entries_as_init_col_ids,
                                                  TableCache &cache,
                                                  std::vector<NodeId> &variable_renaming,
                                                  InitColLookup &lookup);

        void validate_result_leads_to(std::unordered_map<ll, ll> original_var_to_pos,
                                      NodeId transformer_node,
                                      std::unordered_map<ll, ll> &transformed_var_to_pos,
                                      std::unordered_set<ll> &vars_covered,
                                      std::unordered_set<ll> &pos_covered,
                                      NodeId m_node,
                                      bool is_left);
        void validate_join_node(QueryEval::NodeId left, QueryEval::NodeId right, QueryEval::NodeId match,
                                std::unordered_map<ll, ll> &l_var_to_pos,
                                std::unordered_map<ll, ll> &r_var_to_pos,
                                std::unordered_map<ll, ll> &result_var_to_pos);
        void validate_leads_to(NodeId from, NodeId to, bool over_left);

        void sort_by_length_decr(std::vector<NodeId> &nodes);

        bool is_valid_root(NodeId node);
        void insert_valid_nodes(NodeId node, std::vector<NodeId> &valid_last_nodes, std::unordered_set<NodeId> &seen, TableCache &cache);
        void adjust_pars(ll request, std::unordered_map<ll,ll> &old_to_new);
        bool try_match_all_to(ll common_to, const std::vector<ParRef> &common_pars, std::set<ll> &je_exists, std::vector<const AnnotatedJoinOrderElement *> &delayed_elements, std::vector<NodeId> &current_nodes, std::vector<std::unordered_map<ll, ll>> &current_var_to_pos, MultipleJoinMatchHash &h, std::unordered_map<ll,ll> &first_from_pos_to_var);

        bool fully_explored(NodeId _node, TableCache &cache);

        void adjust_pre_f(TableCache &cache, NodeId id, std::queue<NodeId> &q);

        /// Var map of the last request (operator[]: inserts an empty entry if the request is not registered yet).
        auto &get_last_var_map() {
            return requests[get_last_request()].first;
        }

        /// Hands out consecutive request ids starting at 0.
        ll next_req_id() {
            return req_id_count++;
        }
    public:
        JoinOrderGraph(Query &query, AnnotatedJoinOrder &join_order, DBInfo &info);
        JoinOrderGraph(DBInfo &info);
        virtual ~JoinOrderGraph() = default;

        /// External view of a node: its id plus a pointer to the concrete node.
        /// Which union member is valid follows from the kind of `id`.
        struct ExtNodeRepr {
            NodeId id;
            union {
                InitNode *init_node;
                MergeNode *merge_node;
                ReOrderNode *reorder_node;
            } ptr;
        };
        std::vector<ExtNodeRepr> get_nodes(); //TODO: void


        void collect_atoms(InitCollection &init_collection, NodeId int_node, std::vector<ll> &vars_for_init,
                           NodeId *single_request=nullptr,
                           std::unordered_map<ll, std::vector<ll>> *fail_annotation=nullptr, //TODO: can probably remove this
                           TableCache *result_cache=nullptr,  //TODO: can probably remove this
                           std::unordered_map<QueryEval::NodeId, std::vector<std::pair<ll, ll>>> *jog_query_repr_mapping=nullptr);

        void extract_fail_nodes(TableCache &cache, NodeId from, std::unordered_map<QueryEval::NodeId, std::vector<ll>> &fail_annotation);

        // GroundAtomCol for states, InitCollection for variable mappings
        template<typename t> //TODO clean this up, pragma
        void evaluate(t &init_info, TableCache &cache, StartNodeManager &start_node_manager, DBInfo &db_info,
                      std::vector<NodeId> *last_nodes=nullptr,
                      std::unordered_set<NodeId> *increased=nullptr); //TODO: make null/not null template option

        //TODO sort query beforehand, maybe static first, then const atom
        //TODO minimize reorder nodes --- especially if no reorder needed
        //TODO merge;
        //TODO minimize;
        //TODO extend to full reducer;

        /**
         * Grows `container` to arr_id_count entries and registers the nodes whose arr_id is
         * >= last_node_amount. Each per-kind vector is scanned from the back and the scan stops at the
         * first node with a smaller arr_id, so this assumes arr_ids increase along each vector.
         */
        template<typename T>
        void adjust_for_new_nodes(T &container, size_t last_node_amount) {
            container.extend(arr_id_count);

            //TODO: i think only static is adjusted here, so could also just cache these
            // Explicit cast: size() - 1 on an empty vector must become -1, not wrap around as size_t.
            for (ll i = static_cast<ll>(init_nodes.size()) - 1; i >= 0; i--) {
                auto &node = init_nodes.at(i);
                if (node.arr_id < last_node_amount) {
                    break;
                }
                container.register_node(InitializerId(i), node, this);
            }
            for (ll i = static_cast<ll>(merge_nodes.size()) - 1; i >= 0; i--) {
                auto &node = merge_nodes.at(i);
                if (node.arr_id < last_node_amount) {
                    break;
                }
                container.register_node(MergeId(i), node, this);
            }
            for (ll i = static_cast<ll>(reorder_nodes.size()) - 1; i >= 0; i--) {
                auto &node = reorder_nodes.at(i);
                if (node.arr_id < last_node_amount) {
                    break;
                }
                container.register_node(ReorderId(i), node, this);
            }
        }

        /// Whether the single registered request produced a non-empty result table in `cache`.
        bool had_result(TableCache &cache) {
            assert(requests.size() == 1);
            // Tables are indexed by arr_id (as in get_result). The previous code passed the per-kind
            // id_num, which only hit the right table when id_num and arr_id happened to coincide.
            return !cache.get_node_table(get(requests.begin()->second.second).arr_id).empty();
        }

        /// Table cache with one entry per node array id (get_node_arr_am()).
        TableCache create_table_cache_simple(DBInfo &info) {
            return TableCache(get_node_arr_am(), info.predicate_amount);
        }

        /// Delta table cache with one entry per node array id (get_node_arr_am()).
        DeltaTableCache create_delta_table_cache_simple(DBInfo &info) {
            return DeltaTableCache(get_node_arr_am(), info.predicate_amount);
        }

        template<bool use_static>
        RegressiveStartNodeManger<use_static> create_recursive_start_nodes_manager(DBInfo &info) {
            return RegressiveStartNodeManger<use_static>(get_node_arr_am(), info.predicate_amount);
        }

        RegressiveStartNodeMangerStandard create_recursive_start_nodes_manager_standard(DBInfo &info) {
            return create_recursive_start_nodes_manager<true>(info);
        }
        RegressiveStartNodeMangerNoExtraStaticHandle create_recursive_start_nodes_manager_no_static(DBInfo &info) {
            return create_recursive_start_nodes_manager<false>(info);
        }

        /// Start node manager that refers to this graph's predicate initializer map (by pointer).
        SimplePredInitializerWrapManager create_simple_start_node_manager() {
            return SimplePredInitializerWrapManager(&predicate_init_node);
        }

        /// Id of the most recent request (the in-class default is -1).
        ll get_last_request() {
            return last_request;
        }

        /// Number of node array ids handed out so far (size needed by table caches).
        ll get_node_arr_am() {
            return arr_id_count;
        }

        /// Result node of the last request; throws std::out_of_range if it is not registered.
        NodeId get_last_request_node() {
            return requests.at(get_last_request()).second;
        }

        /// Returns (var map, result table in `cache`) for `request_id`.
        /// An unknown id asserts in debug builds and throws std::out_of_range otherwise
        /// (operator[] used to silently insert an empty request here).
        std::pair<std::unordered_map<ll, ll>&, JoinTable::Table&> get_result(ll request_id, TableCache &cache) { //Return var_map, results
            assert(requests.find(request_id) != requests.end());
            auto &request = requests.at(request_id);

            return {request.first, cache.get_node_table(get(request.second).arr_id)};
        }

        bool is_static(NodeId id) override {
            return get(id).is_static;
        }

        /*
         * TODO: never got around to finishing the implementation; by now it should be easy via temporary_table and the start_node manager
         * TODO: though we would probably have to add something that bounds the extraction in case it gets too large for small subqueries
         * merging variant inspired by Chandra Chekuri, Anand Rajaraman: Conjunctive query containment revisited
         * https://core.ac.uk/download/pdf/81184999.pdf --
         */
        void add_new_query(Query &query, DBInfo &info);
        void add_new_query(InitCollection &init_collection, std::vector<ll> &new_result_pars, std::unordered_map<ll,ll> &old_to_new, DBInfo &info);

        ll arr_get(NodeId id) override {
            return get(id).arr_id;
        }

        template<bool use_static>
        void mark_for_exploration(NodeId id, RegressiveStartNodeManger<use_static> &manager);

#ifndef NDEBUG
        // these functions shall be used for debugging only

        auto &pub_get(NodeId id) {
            return get(id);
        }
#endif
    };
}

}

#include "join_order_graph.tpp"

#endif //HELP_CORE_JOIN_ORDER_GRAPH_H
