#ifndef HELP_CORE_QUERY_SORT_H
#define HELP_CORE_QUERY_SORT_H

#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "../task_formulations/db_propositional_shared.h"

namespace HELP { 

/*
 * The order of atoms in a query has a large impact on join-order generation based on a sufficient criterion:
 * if a < b and both a and b satisfy the criterion, "a" is the one used for the join.
 */

namespace QueryEval {
    
/// Arguments accepted by query_sort().
/// Both members are non-owning references; the referenced objects must
/// outlive this struct. Member order matters for aggregate initialisation.
struct CmpInfoArgs {
    const DBInfo& db;            // database info, forwarded to the comparators
    const Parameters& res_pars;  // parameters treated as result parameters when counting per atom
};

/// Read-only context handed to every AtomComparator.
/// Holds non-owning references; the referenced objects must outlive it.
/// Member order matters for aggregate initialisation.
struct CmpInfo {
    // Database info; comparators read static_predicates, most_duplicated
    // and init_table_size from it.
    const DBInfo& db;
    // Atom id -> number of result parameters in that atom (looked up with .at()).
    const std::unordered_map<ll, ll>& res_par_count; //TODO: make this a lookup
};

/// An atom paired with an integer id.
/// In query_sort the id is the atom's index in the unsorted input. The id is
/// used as the key into CmpInfo::res_par_count and as the final tie-breaker of
/// the ordering. Members have default initialisers so a default-initialised
/// IdAtom never holds indeterminate values; the type remains an aggregate, so
/// `IdAtom{id, &atom}` keeps working.
struct IdAtom {
    ll id = 0;             // position of the atom in the original query
    Atom *atom = nullptr;  // non-owning; in query_sort it points into the vector being sorted
};

/// Function type shared by all atom comparators: returns true when atom1
/// should be placed before atom2, given the shared comparison context.
/// Used as the (non-type) template parameter kind of multiple_atom_comp.
using AtomComparator =
    bool (const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info);

/// Places atoms whose is_negated() is false before atoms whose is_negated() is true.
/// The context argument is unused; it exists only to match AtomComparator.
inline bool is_not_negated(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo & /*info*/) {
    const auto first_negated = atom1.atom->is_negated();
    const auto second_negated = atom2.atom->is_negated();
    return first_negated < second_negated;
}

/// Places atoms whose predicate is in db.static_predicates before atoms whose
/// predicate is not.
inline bool is_static(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info) {
    const auto &statics = info.db.static_predicates;
    const auto first_static = statics.contains(atom1.atom->get_predicate());
    const auto second_static = statics.contains(atom2.atom->get_predicate());
    // Reversed comparison: "static" (true) sorts before "not static" (false).
    return second_static < first_static;
}

inline bool has_result_par(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info) { //TODO: probably unimportant for static, should factor that in
    return info.res_par_count.at(atom1.id) < info.res_par_count.at(atom2.id); //TODO important: don't recompute every time
}

// Orders atoms by the db.most_duplicated entry of their predicate, smallest first.
// Every predicate compared must have an entry (looked up with .at()).
// e.g. equality is "useful" because it always keeps the number of rows <= what it was before.
//TODO: maybe we should not do this here, but in join_order_generator over actual joined pars
inline bool most_duplicated_object(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info) {
    const auto &duplication = info.db.most_duplicated;
    const auto &first_value = duplication.at(atom1.atom->get_predicate());
    const auto &second_value = duplication.at(atom2.atom->get_predicate());
    return first_value < second_value;
}

/// Orders atoms by the db.init_table_size entry of their predicate, smallest first.
/// Every predicate compared must have an entry (looked up with .at()).
inline bool init_table_size(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info) {
    const auto &sizes = info.db.init_table_size;
    const auto &first_size = sizes.at(atom1.atom->get_predicate());
    const auto &second_size = sizes.at(atom2.atom->get_predicate());
    return first_size < second_size;
}

/// Tie-breaker: orders atoms by descending id, so the atom with the larger id
/// (in query_sort: the one appearing later in the original query) comes first.
/// NOTE(review): descending rather than ascending order — confirm this is intended.
/// The context argument is unused; it exists only to match AtomComparator.
inline bool id_comp(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo & /*info*/) {
    return atom1.id > atom2.id;
}

template<AtomComparator... comparators>
struct multiple_atom_comp;

template<AtomComparator comp, AtomComparator... more_comp>
struct multiple_atom_comp<comp, more_comp...> { //TODO: move to util independant of atom type
    static inline bool do_comp(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info) {
        if (comp(atom1, atom2, info)) {
            return true;
        } else if (comp(atom2, atom1, info)) {
            return false;
        } else {
            return multiple_atom_comp<more_comp...>::do_comp(atom1, atom2, info);
        }
    }
};

template<AtomComparator comp>
struct multiple_atom_comp<comp> {
    static inline bool do_comp(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info) {
        if (comp(atom1, atom2, info)) {
            return true;
        }

        return false;
    }
};

inline bool current_comp(const IdAtom &atom1, const IdAtom &atom2, const CmpInfo &info) {
    return multiple_atom_comp<is_not_negated, is_static, has_result_par, most_duplicated_object, init_table_size, id_comp>::do_comp(atom1, atom2, info);
}

/// Function-object adapter around current_comp() for standard algorithms and
/// containers. Holds a non-owning reference to the comparison context, which
/// must outlive the comparator.
struct queryorder_comp {
    const CmpInfo &info;

    /// Returns true if atom1 should be ordered before atom2.
    /// const-qualified: comparators are routinely invoked through a const
    /// reference (Compare requirements, associative containers such as
    /// std::set/std::map), which a non-const operator() cannot satisfy.
    bool operator() (const IdAtom &atom1, const IdAtom &atom2) const {
        return current_comp(atom1, atom2, info); //TODO: probably want to be able to switch between different variants at some point
    }

    queryorder_comp(const CmpInfo &info) : info(info) {}
};

/// Sorts `atoms` in place using queryorder_comp (i.e. current_comp()).
/// std::sort is not stable, but when ids are distinct the id_comp tie-breaker
/// ensures no two atoms compare equal, so the result is still deterministic.
inline void _query_sort(std::vector<IdAtom> &atoms, const CmpInfo &info) {
    std::sort(atoms.begin(), atoms.end(), queryorder_comp(info));
}


/// Counts the distinct result parameters that occur as variable arguments of
/// `atom`; a parameter appearing several times is counted once.
///
/// @param atom      atom whose arguments are inspected
/// @param res_pars  lookup set of result parameters. It is only read, so it is
///                  taken by const reference. Non-const lvalues still bind, so
///                  existing callers are unaffected; const sets and
///                  temporaries are now accepted too.
/// @return          number of distinct argument indices contained in `res_pars`
inline ll count_result_pars(const Atom &atom, const std::unordered_set<ParRef> &res_pars) {
    std::unordered_set<ParRef> found;

    for (auto &arg : atom.get_args()) {
        //TODO: should use custom iterator for that
        if (arg.is_variable() && res_pars.contains(arg.get_index())) {
            found.insert(arg.get_index());
        }
    }

    // size() is unsigned; make the conversion to the signed return type explicit.
    return static_cast<ll>(found.size());
}

//TODO: clean up these wrappers
inline std::vector<Atom> &query_sort(std::vector<Atom> &_atoms, CmpInfoArgs info) { //TODO: rn _atoms?
    std::vector<IdAtom> id_atoms;
    ll id = 0;
    for (auto &atom : _atoms) {
        id_atoms.push_back({id++, &atom});
    }

    std::unordered_map<ll, ll> res_par_count; //TODO: proper types; & could either map atom pointer or be a vec
    std::unordered_set<ParRef> res_lookup(info.res_pars.begin(), info.res_pars.end());
    for (auto &atom : id_atoms) {
        res_par_count.emplace(atom.id, count_result_pars(*atom.atom, res_lookup));
    }

    CmpInfo cmp_info{info.db, res_par_count};
    _query_sort(id_atoms, cmp_info);


    std::vector<Atom> atoms;
    for (auto &atom : id_atoms) {
        atoms.push_back(*atom.atom);
    }
    _atoms = atoms;
    return _atoms;
}

}

}

#endif //HELP_CORE_QUERY_SORT_H
