#ifndef HELP_CORE_JOIN_ORDER_GENERATOR_H
#define HELP_CORE_JOIN_ORDER_GENERATOR_H

#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

#include "../join_and_project.h"

namespace HELP { 

namespace QueryEval {
    /// Strategy interface for deciding the order in which the atoms of a
    /// query are joined. Concrete strategies override generate().
    class JoinOrderGenerator {
    public:
        virtual ~JoinOrderGenerator() = default;

        /// Compute a join order for `query`.
        virtual JoinOrder generate(Query &query) = 0;
    };

    /// Generator used by create_join_order(); defined in another translation
    /// unit. NOTE(review): ownership and initialisation point are not visible
    /// here — confirm who installs (and frees) it.
    extern JoinOrderGenerator *join_order_generator;

    /// Produce a join order for `query` using the globally installed
    /// `join_order_generator`.
    ///
    /// Precondition: `join_order_generator` has been set to a non-null
    /// generator (checked by assert in debug builds).
    inline JoinOrder create_join_order(Query &query) {
        assert(join_order_generator != nullptr && "join_order_generator has not been installed");
        // generate() returns a prvalue of the same type, so the result is
        // constructed directly in the caller's storage (guaranteed copy
        // elision since C++17); an explicit std::move would add nothing.
        return join_order_generator->generate(query);
    }

    /// Debug check (used in an assert by annotate()): verifies that no step of
    /// `annotated_order` drops a result parameter that still has to be carried.
    ///
    /// For a step (from_id -> to_id) the required set is the union of
    ///   - the result parameters contained in query.get_pars(from_id) and
    ///     query.get_pars(to_id), and
    ///   - the result parameters that earlier steps required to be carried
    ///     into from_id or into to_id.
    /// The step's `pars_tracked` must contain the whole required set. The set
    /// is then recorded as carried by to_id. Earlier requirements on to_id are
    /// merged rather than overwritten, so several joins into the same atom are
    /// all checked downstream.
    ///
    /// @param annotated_order  the annotated join order to validate
    /// @param result_pars      the query's result parameters
    /// @param query            the query the order was built for
    /// @return true if every step's pars_tracked covers its required set
    inline bool doesnt_drop_result_pars(AnnotatedJoinOrder &annotated_order, Parameters &result_pars, Query &query) { //for debugging
        const std::set<ParRef> result_pars_s(result_pars.begin(), result_pars.end());

        // should_track[atom]: result parameters earlier steps merged into `atom`
        std::vector<std::set<ParRef>> should_track(query.size());

        // Inserts into `out` the result parameters of atom's own parameter set.
        auto add_own_result_pars = [&](auto atom, std::set<ParRef> &out) {
            std::set<ParRef> &pars = query.get_pars(atom);
            std::set_intersection(pars.begin(), pars.end(),
                                  result_pars_s.begin(), result_pars_s.end(),
                                  std::inserter(out, out.end()));
        };

        for (auto &order_el : annotated_order) {
            const auto from_id = order_el.element.from_id;
            const auto to_id = order_el.element.to_id;

            // Everything either side must still carry after this join.
            std::set<ParRef> required = should_track[from_id];
            required.insert(should_track[to_id].begin(), should_track[to_id].end());
            add_own_result_pars(from_id, required);
            add_own_result_pars(to_id, required);

            const std::set<ParRef> annotated(order_el.pars_tracked.begin(), order_el.pars_tracked.end());
            if (!std::includes(annotated.begin(), annotated.end(),
                               required.begin(), required.end())) {
                return false;
            }

            // `required` already contains the previous should_track[to_id].
            should_track[to_id] = std::move(required);
        }

        return true;
    }

    //TODO: could combine this with Join Order extraction itself, should do this at some point
    /// Annotate every step of `join_order` with the parameters it joins on and
    /// the parameters that are carried forward after the step.
    ///
    /// For each element (from_id -> to_id), with pars(x) being the parameters
    /// atom x currently carries:
    ///   - join pars    = pars(from) ∩ pars(to)
    ///   - tracked pars = (pars(from) ∪ pars(to)) minus every non-result
    ///                    parameter whose remaining-use counter reached zero
    /// Afterwards, pars(to) becomes the tracked pars for later steps.
    /// Result parameters are never dropped. Debug builds cross-check the
    /// output with doesnt_drop_result_pars().
    inline AnnotatedJoinOrder annotate(Query &query, JoinOrder &join_order) { //TODO: probably should become a constructor at some point
        Parameters &result_pars = query.get_result_pars();
        const std::set<ParRef> result_par_set(result_pars.begin(), result_pars.end());

        // remaining_uses[p] starts at query.atoms_for(p).size(). It is
        // decremented when an atom carrying p is joined away, and incremented
        // when p is newly carried into a target atom.
        std::vector<ll> remaining_uses; //TODO: size_t?
        for (ll par = 0; par < query.get_par_amount(); par++) {
            remaining_uses.push_back(query.atoms_for(par).size());
        }

        // pars_of[atom]: parameters currently carried by `atom`
        std::vector<std::set<ParRef>> pars_of;
        for (ll atom = 0; atom < query.size(); atom++) {
            pars_of.push_back(query.get_pars(atom));
        }

        // non-result parameters with no remaining uses; never carried again
        std::set<ParRef> dropped;
        AnnotatedJoinOrder annotated;

        for (auto &step : join_order) {
            auto &lhs = pars_of[step.from_id];

            // lhs is joined away: release each of its non-result parameters
            for (auto par : lhs) {
                if (!result_par_set.contains(par) && --remaining_uses[par] == 0) {
                    dropped.insert(par);
                }
            }

            auto &rhs = pars_of[step.to_id];

            std::set<ParRef> shared; // lhs ∩ rhs
            std::set<ParRef> kept;   // (lhs ∪ rhs) minus dropped
            for (auto par : lhs) {
                if (rhs.contains(par)) {
                    shared.insert(par);
                }
                if (!dropped.contains(par)) {
                    kept.insert(par);
                }
            }
            for (auto par : rhs) {
                if (!dropped.contains(par)) {
                    kept.insert(par);
                }
            }

            // parameters the target atom did not carry before gain one use
            for (auto par : kept) {
                if (!rhs.contains(par)) {
                    ++remaining_uses[par];
                }
            }

            //TODO: res_pars vs tracked_pars (a separate result_tracked_pars field was sketched)
            annotated.push_back({
                step,
                {shared.begin(), shared.end()},
                {kept.begin(), kept.end()}
            });

            pars_of[step.to_id] = kept;
        }

        assert(doesnt_drop_result_pars(annotated, result_pars, query));
        //TODO: assert(all_connected);

        return annotated;
    }

    inline void try_optimize_jo(AnnotatedJoinOrder &jo) {
        //TODO: potential reorder by static load
    }

    /// Full pipeline: generate a join order for `query`, annotate it with
    /// join/tracked parameters, then run the (currently no-op) optimisation
    /// hook on the annotated order before returning it.
    inline AnnotatedJoinOrder create_join_order_and_annotate(Query &query) {
        JoinOrder order = create_join_order(query);
        AnnotatedJoinOrder annotated = annotate(query, order);
        try_optimize_jo(annotated);
        return annotated;
    }
}

}
#endif //HELP_CORE_JOIN_ORDER_GENERATOR_H
