#ifndef HELP_CORE_JOINGRAPH_NODES_H
#define HELP_CORE_JOINGRAPH_NODES_H

#include "../../utils/type_defs.h"
#include <vector>
#include <set>
#include <map>
#include <unordered_map>
#include <boost/variant.hpp>
#include "../../utils/hashes.h"

namespace HELP { 

namespace QueryEval {
    // Kind tag of a join-graph node. NodeId stores this value in a 2-bit
    // bit-field, so the enumerators are numbered explicitly and have to stay
    // within 0..3.
    enum NodeType {
        INIT = 0,
        MERGE = 1,
        REORDER = 2,
        _NO_NODE = 3
    };

    //TODO: probably the best example of "just use inheritance" -- we should restructure this at some point
    // Compact node handle: a NodeType tag and a numeric id packed into two
    // adjacent bit-fields. The struct is an aggregate and is initialised as
    // {type, id_num}.
    struct NodeId {
        // Number of bits reserved for the NodeType tag.
        static constexpr int type_bits = 2;

        // Every enumerator has to be representable in the tag bit-field;
        // _NO_NODE is the last (largest) enumerator of NodeType.
        static_assert(static_cast<int>(NodeType::_NO_NODE) < (1 << type_bits),
                      "NodeType does not fit into the NodeId::type bit-field");

        NodeType type : type_bits;
        ll id_num : 8*sizeof(ll)-type_bits; //TODO: make this more generic with util
        //TODO static assert size of one ll

        // True when this id carries the INIT tag.
        bool is_init() const {
            return type == NodeType::INIT;
        }
        // True when this id carries the MERGE tag.
        bool is_merge() const {
            return type == NodeType::MERGE;
        }
        // True when this id carries the REORDER tag.
        bool is_reorder() const {
            return type == NodeType::REORDER;
        }
        // Returns the stored tag unchanged.
        NodeType get_enum_val() const {
            return type;
        }

        // Member-wise three-way comparison in declaration order: type, then id_num.
        auto operator<=>(const NodeId&) const = default;
    };

    // Returns a NodeId tagged NodeType::INIT whose id_num is `id`.
    inline NodeId InitializerId(ll id) {
        return NodeId{NodeType::INIT, id};
    }

    // Returns a NodeId tagged NodeType::MERGE whose id_num is `id`.
    inline NodeId MergeId(ll id) {
        return NodeId{NodeType::MERGE, id};
    }

    // Returns a NodeId tagged NodeType::REORDER whose id_num is `id`.
    inline NodeId ReorderId(ll id) {
        return NodeId{NodeType::REORDER, id};
    }

    // Declaration only; the object is defined elsewhere.
    // NOTE(review): presumably a "no node" placeholder tagged NodeType::_NO_NODE -- confirm against the definition.
    extern NodeId NO_NODE;

    // TODO: initialization in all classes here is a total mess, try to follow one principle

    // Fields common to the node kinds of the join graph; ReOrderNode and
    // MergeNode derive from this struct.
    // Note: arr_id and is_static have no default member initializer, so they
    // hold indeterminate values unless the creator sets them.
    struct JoinGraphNode { //TODO: probably should just be a class and contain multiple virtual methods
        ll arr_id; //TODO: wrap ll
        ll merge_node_depth = 0; //TODO: only 0-initialize for initnodes
        bool is_static; //TODO: should wrap this and merge node depth to automatically compute while upwards construction //TODO: could just use 1 bit of another field

        // Ids of the nodes connected to this one.
        // NOTE(review): edge direction is not visible from this header -- confirm with the graph builder.
        std::vector<NodeId> edges;
    };

    // A (from, to) pair of positions. Equality is member-wise; the ordering is
    // lexicographic on (from, to).
    struct PositionForward {
        ll from;
        ll to; //TODO: get rid of this

        bool operator==(const PositionForward &other) const = default;

        // Orders by `from` first and uses `to` only to break ties.
        bool operator<(const PositionForward &other) const {
            if (from != other.from) {
                return from < other.from;
            }
            return to < other.to;
        }
    };

    // A (from_left, from_right, to) triple of positions. Equality is
    // member-wise; the ordering is lexicographic on
    // (from_left, from_right, to).
    struct PositionMerge {
        ll from_left;
        ll from_right;
        ll to; //TODO: get rid of this

        bool operator==(const PositionMerge &other) const = default;

        // Compares the members in declaration order; the first member that
        // differs decides the result.
        bool operator<(const PositionMerge &other) const {
            if (from_left != other.from_left) {
                return from_left < other.from_left;
            }
            if (from_right != other.from_right) {
                return from_right < other.from_right;
            }
            return to < other.to;
        }
    };

    // Join-graph node with a single input node and a list of
    // PositionForward (from, to) entries.
    struct ReOrderNode : public JoinGraphNode { //TODO: join not merge
        // Id of the input node.
        NodeId from;
        // NOTE(review): presumably maps input position `from` to output position `to` -- confirm with the evaluator.
        std::vector<PositionForward> re_order;
    };

    struct MergeNode : public JoinGraphNode { //TODO: join not merge
        NodeId left;
        NodeId right;

        bool right_negated = false; //TODO: make this an extra type

        std::vector<PositionForward> left_positions_forward; //TODO: just store amount
        std::vector<PositionForward> right_positions_forward; //TODO: just store amount
        std::vector<PositionMerge> joined_positions; //TODO: just store amount

        size_t get_row_size() {
            return left_positions_forward.size()
                   + right_positions_forward.size()
                   + joined_positions.size();
        }

        //TODO potential_result_pars_list; //TODO: minimize result par traces
    };

    inline std::set<ll> from_to_set(std::vector<PositionForward> &pfw) {
        std::set<ll> s;
        for (auto &[from, to] : pfw) {
            s.insert(from);
        }
        return s;
    }

    // Hashable key made of a target node id and a list of
    // (node, position pairs) entries. Equality is member-wise, so the order of
    // `froms` and of the inner position lists matters.
    struct MultipleJoinMatchHash {
        QueryEval::NodeId to;
        std::vector<std::pair<QueryEval::NodeId, std::vector<std::pair<ll,ll>>>> froms; // pairs (node to join, positions to join)

        bool operator==(const MultipleJoinMatchHash &other) const = default;
    };

    struct CreateJoinNodeCacheHash {
        QueryEval::NodeId left;
        QueryEval::NodeId right;
        std::set<ll> left_vars_to_forward;
        std::set<ll> right_vars_to_forward;
        std::set<PositionMerge> join_rules;

        CreateJoinNodeCacheHash(QueryEval::NodeId left, QueryEval::NodeId right,
            std::vector<PositionForward> &left_vars_to_forward,
            std::vector<PositionForward> &right_vars_to_forward,
            std::vector<PositionMerge> &join_rules)
            : left(left),
              right(right),
              left_vars_to_forward(from_to_set(left_vars_to_forward)),
              right_vars_to_forward(from_to_set(right_vars_to_forward)),
              join_rules(join_rules.begin(), join_rules.end())
        {}

        bool operator==(const CreateJoinNodeCacheHash &other) const = default;
    };

    struct AdaptiveReorderHash {
        QueryEval::NodeId from;
        std::set<ll> forward;
        std::set<PositionForward> merge;

        AdaptiveReorderHash(QueryEval::NodeId from,
                                std::vector<PositionForward> &forward,
                                std::vector<PositionForward> &merge)
            : from(from),
              forward(from_to_set(forward)),
              merge(merge.begin(), merge.end())
        {}

        bool operator==(const AdaptiveReorderHash &other) const = default;
    };

    // Hashable key made of two node ids, the forwarded positions of each side
    // and the join rules. The lists are kept as vectors and compared
    // member-wise, so element order matters for equality.
    struct MergeNodeIdentifierHashKey {
        QueryEval::NodeId left;
        QueryEval::NodeId right;
        std::vector<PositionForward> left_vars_to_forward;
        std::vector<PositionForward> right_vars_to_forward;
        std::vector<PositionMerge> join_rules;

        bool operator==(const MergeNodeIdentifierHashKey &other) const = default;
    };

    // Abstract interface for looking up per-node properties by NodeId.
    // No implementation is provided in this header.
    class NodeLookup {
    public:
        // Returns an `ll` value for the node `id`.
        // NOTE(review): presumably the node's JoinGraphNode::arr_id -- confirm against the implementers.
        virtual ll arr_get(NodeId id) = 0;
        // Returns a flag for the node `id`.
        // NOTE(review): presumably the node's JoinGraphNode::is_static -- confirm against the implementers.
        virtual bool is_static(NodeId id) = 0;
        // Virtual so that implementations can be destroyed through a NodeLookup pointer.
        virtual ~NodeLookup() = default;
    };
}

// TODO: strictly speaking this is wrong: writing `node_id` and then reading `value` is type punning through a union, which is undefined behaviour in C++ : https://stackoverflow.com/questions/2468708/converting-bit-field-to-int
// Compare to: https://stackoverflow.com/questions/40440468/how-can-an-int-be-converted-to-unsigned-int-while-preserving-the-original-bi -- should be okay even if the sizes are not the same
// Overlays a NodeId with a size_t through an anonymous union so that the
// bits of the id can be read back as an integer.
struct NodeIdToLL {
    union {
        QueryEval::NodeId node_id;
        size_t value; //TODO static assert size_t and nodeId have the same size
    };
};

}

namespace std {
// Hash for NodeId, computed directly from its two fields.
// The previous implementation wrote one union member and read another, which
// is undefined behaviour and silently depends on the bit-field layout.
template<> struct hash<HELP::QueryEval::NodeId> {
    std::size_t operator()(const HELP::QueryEval::NodeId& n_id) const noexcept {
        // Conversion to an unsigned type is well defined for negative ids too.
        const std::size_t id_part = static_cast<std::size_t>(n_id.id_num);
        // Mask to the two tag bits so the result does not depend on whether
        // the enum bit-field is treated as signed by the compiler.
        const std::size_t type_part = static_cast<std::size_t>(n_id.type) & std::size_t{0x3};
        // Id in the upper bits, tag in the lowest two.
        return (id_part << 2) | type_part;
    }
};

// Hash for PositionForward.
// The previous version shifted `from` by sizeof(ll)/2, i.e. by 4 *bits*, so
// the two fields overlapped and e.g. (from=1, to=0) collided with
// (from=0, to=16). The fields are now mixed with a hash_combine-style step.
template<> struct hash<HELP::QueryEval::PositionForward> {
    std::size_t operator()(const HELP::QueryEval::PositionForward& pf) const noexcept {
        // Order-sensitive combiner in the style of boost::hash_combine.
        const auto mix = [](std::size_t seed, std::size_t value) noexcept -> std::size_t {
            return seed ^ (value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL)
                           + (seed << 6) + (seed >> 2));
        };
        std::size_t seed = static_cast<std::size_t>(pf.from);
        seed = mix(seed, static_cast<std::size_t>(pf.to));
        return seed;
    }
};

// Hash for PositionMerge.
// The previous version shifted by sizeof(ll)*2/3 and sizeof(ll)/3, i.e. by 5
// and 2 *bits*, so the three fields overlapped heavily. The fields are now
// mixed with a hash_combine-style step.
template<> struct hash<HELP::QueryEval::PositionMerge> {
    std::size_t operator()(const HELP::QueryEval::PositionMerge& pf) const noexcept {
        // Order-sensitive combiner in the style of boost::hash_combine.
        const auto mix = [](std::size_t seed, std::size_t value) noexcept -> std::size_t {
            return seed ^ (value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL)
                           + (seed << 6) + (seed >> 2));
        };
        std::size_t seed = static_cast<std::size_t>(pf.from_left);
        seed = mix(seed, static_cast<std::size_t>(pf.from_right));
        seed = mix(seed, static_cast<std::size_t>(pf.to));
        return seed;
    }
};

// Hash for MergeNodeIdentifierHashKey.
// The previous version XOR-ed the component hashes, so equal left/right
// forward lists cancelled each other out, and its shift of sizeof(ll)/2 was
// only 4 bits. The components are now combined order-sensitively.
template<> struct hash<HELP::QueryEval::MergeNodeIdentifierHashKey> {
    std::size_t operator()(const HELP::QueryEval::MergeNodeIdentifierHashKey& h_id) const noexcept {
        // Order-sensitive combiner in the style of boost::hash_combine.
        const auto mix = [](std::size_t seed, std::size_t value) noexcept -> std::size_t {
            return seed ^ (value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL)
                           + (seed << 6) + (seed >> 2));
        };
        std::size_t seed = std::hash<HELP::QueryEval::NodeId>{}(h_id.left);
        seed = mix(seed, std::hash<HELP::QueryEval::NodeId>{}(h_id.right));
        seed = mix(seed, std::hash<std::vector<HELP::QueryEval::PositionForward>>{}(h_id.left_vars_to_forward));
        seed = mix(seed, std::hash<std::vector<HELP::QueryEval::PositionForward>>{}(h_id.right_vars_to_forward));
        seed = mix(seed, std::hash<std::vector<HELP::QueryEval::PositionMerge>>{}(h_id.join_rules));
        return seed;
    }
};

// Hash for AdaptiveReorderHash.
// The previous version XOR-ed the component hashes, which is insensitive to
// which component a value came from and lets equal component hashes cancel.
// The components are now combined order-sensitively.
template<> struct hash<HELP::QueryEval::AdaptiveReorderHash> {
    std::size_t operator()(const HELP::QueryEval::AdaptiveReorderHash& h_id) const noexcept {
        // Order-sensitive combiner in the style of boost::hash_combine.
        const auto mix = [](std::size_t seed, std::size_t value) noexcept -> std::size_t {
            return seed ^ (value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL)
                           + (seed << 6) + (seed >> 2));
        };
        std::size_t seed = std::hash<HELP::QueryEval::NodeId>{}(h_id.from);
        seed = mix(seed, std::hash<std::set<HELP::ll>>{}(h_id.forward));
        seed = mix(seed, std::hash<std::set<HELP::QueryEval::PositionForward>>{}(h_id.merge));
        return seed;
    }
};

// Hash for MultipleJoinMatchHash.
// The previous version XOR-ed the two component hashes, which lets equal
// component hashes cancel to zero. They are now combined order-sensitively.
template<> struct hash<HELP::QueryEval::MultipleJoinMatchHash> {
    std::size_t operator()(const HELP::QueryEval::MultipleJoinMatchHash& h_id) const noexcept {
        // Order-sensitive combiner in the style of boost::hash_combine.
        const auto mix = [](std::size_t seed, std::size_t value) noexcept -> std::size_t {
            return seed ^ (value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL)
                           + (seed << 6) + (seed >> 2));
        };
        std::size_t seed = std::hash<HELP::QueryEval::NodeId>{}(h_id.to);
        seed = mix(seed, std::hash<std::vector<std::pair<HELP::QueryEval::NodeId, std::vector<std::pair<HELP::ll,HELP::ll>>>>>{}(h_id.froms));
        return seed;
    }
};

// Hash for CreateJoinNodeCacheHash.
// The previous version XOR-ed the component hashes, so keys whose left and
// right forward sets were equal lost both contributions, and its shift of
// sizeof(ll)/2 was only 4 bits. The components are now combined
// order-sensitively.
template<> struct hash<HELP::QueryEval::CreateJoinNodeCacheHash> {
    std::size_t operator()(const HELP::QueryEval::CreateJoinNodeCacheHash& h_id) const noexcept {
        // Order-sensitive combiner in the style of boost::hash_combine.
        const auto mix = [](std::size_t seed, std::size_t value) noexcept -> std::size_t {
            return seed ^ (value + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL)
                           + (seed << 6) + (seed >> 2));
        };
        std::size_t seed = std::hash<HELP::QueryEval::NodeId>{}(h_id.left);
        seed = mix(seed, std::hash<HELP::QueryEval::NodeId>{}(h_id.right));
        seed = mix(seed, std::hash<std::set<HELP::ll>>{}(h_id.left_vars_to_forward));
        seed = mix(seed, std::hash<std::set<HELP::ll>>{}(h_id.right_vars_to_forward));
        seed = mix(seed, std::hash<std::set<HELP::QueryEval::PositionMerge>>{}(h_id.join_rules));
        return seed;
    }
};
}

#endif //HELP_CORE_JOINGRAPH_NODES_H
