#ifndef SEARCH_HELP_HEURISTIC_H
#define SEARCH_HELP_HEURISTIC_H

#include "heuristic.h"
#include "../HeLP-core/src/heuristics/heuristic.h"

#include <cstddef>
#include <limits>
#include <string>
#include <vector>

template<typename HHeuristic> //TODO: make sure HELP::Heuristicc
class HeLPHeuristic : public Heuristic {
    HELP::LiftedStripsTask task;
    HHeuristic heuristic;
    HELP::ll just_t_id;

    inline void add_state_repr(std::vector<HELP::GroundAtom> &repr, const DBState &state, size_t pred_am, bool add_nullary_obj=false) {
        // add normal atoms
        for (auto &atom_container : state.get_relations()) {
            for (auto &ref_objects : atom_container.tuples) {
                std::vector<HELP::ObjectRef> objects;
                for (auto obj : ref_objects) {
                    objects.push_back(obj);
                }
                repr.emplace_back(atom_container.predicate_symbol, objects);
            }
        }

        // add nullary atoms
        auto &nullaries = state.get_nullary_atoms();
        auto objects_in_nullary = add_nullary_obj ? std::vector<HELP::ObjectRef>{just_t_id} : std::vector<HELP::ObjectRef>{};
        for (size_t i = 0; i < nullaries.size(); i++) {
            if (nullaries[i]) {
                repr.emplace_back(i, objects_in_nullary);
            }
        }
    }

    HELP::LiftedStripsTask task_adapter(const Task &task) {
        // construct task
        std::vector<HELP::Predicate> preds;
        for (auto &pred : task.predicates) {
            preds.emplace_back(pred.get_name(), pred.getArity());
        }

        HELP::Objects objs;
        for (auto &obj : task.objects) {
            objs.emplace_back(obj.get_name());
        }

        std::vector<HELP::Action> actions;
        for (auto &action : task.get_action_schemas()) {
            std::vector<std::string> var_names;
            for (auto &par : action.get_parameters()) {
                var_names.push_back(par.name);
            }

            std::vector<HELP::Atom> pre;
            for (auto &atom : action.get_precondition()) {
                if (atom.is_negated()) { //TODO: should this generally be the case?
                    continue;
                }
                std::vector<HELP::ParameterOrObject> args;
                for (auto &arg : atom.get_arguments()) {
                    args.emplace_back(!arg.is_constant(), arg.get_index());
                }
                pre.emplace_back(atom.is_negated(), atom.get_predicate_symbol_idx(), args);
            }
            for (size_t p_id = 0; p_id < action.get_positive_nullary_precond().size(); p_id++) {
                if (action.get_positive_nullary_precond()[p_id]) pre.emplace_back(false, p_id, std::vector<HELP::ParameterOrObject>{});
            }
            for (size_t p_id = 0; p_id < action.get_negative_nullary_precond().size(); p_id++) {
                if (action.get_negative_nullary_precond()[p_id]) pre.emplace_back(true, p_id, std::vector<HELP::ParameterOrObject>{});
            }

            std::vector<HELP::Atom> add;
            std::vector<HELP::Atom> del;
            for (auto &atom : action.get_effects()) {
                std::vector<HELP::ParameterOrObject> args;
                for (auto &arg : atom.get_arguments()) {
                    args.emplace_back(!arg.is_constant(), arg.get_index());
                }
                if (atom.is_negated()) {
                    del.emplace_back(false, atom.get_predicate_symbol_idx(), args);
                } else {
                    add.emplace_back(false, atom.get_predicate_symbol_idx(), args);
                }
            }
            for (size_t p_id = 0; p_id < action.get_positive_nullary_effects().size(); p_id++) {
                if (action.get_positive_nullary_effects()[p_id]) add.emplace_back(false, p_id, std::vector<HELP::ParameterOrObject>{});
            }
            for (size_t p_id = 0; p_id < action.get_negative_nullary_effects().size(); p_id++) {
                if (action.get_negative_nullary_effects()[p_id]) del.emplace_back(false, p_id, std::vector<HELP::ParameterOrObject>{});
            }

            actions.emplace_back(action.get_name(),
                                 var_names,
                                 pre,
                                 add,
                                 del,
                                 action.get_cost());
        }

        std::vector<HELP::GroundAtom> i_state_list;
        add_state_repr(i_state_list, task.initial_state, preds.size());
        add_state_repr(i_state_list, task.static_info, preds.size());

        HELP::StripsState initial_state(i_state_list, preds.size());

        HELP::StripsGoal goal;
        for (auto &sg : task.get_goal().goal) {
            std::vector<HELP::ParameterOrObject> args;
            for (auto obj : sg.get_arguments()) {
                args.emplace_back(false, obj);
            }
            goal.emplace_back(sg.is_negated(), sg.get_predicate_index(), args);
        }
        for (auto p_id : task.get_goal().positive_nullary_goals) {
            goal.emplace_back(false, p_id, std::vector<HELP::ParameterOrObject>{});
        }
        for (auto p_id : task.get_goal().negative_nullary_goals) {
            goal.emplace_back(false, p_id, std::vector<HELP::ParameterOrObject>{});
        }

        HELP::LiftedStripsTask r_task(
            preds,
            objs,
            actions,
            initial_state,
            goal
        );

        // perform task transformations
        auto ret = r_task.normalize();
        r_task.simple_datalog_transformation(); //TODO: this is again only needed for regr add

        just_t_id = ret.just_t_id;

        return r_task;
    }

    HELP::StripsState state_adapter(const DBState &state) { //TODO: it would be very cool to just adjust this by the delta to previous
        std::vector<HELP::GroundAtom> repr;
        add_state_repr(repr, state, task.get_predicate_amount(), true);

        return HELP::StripsState(repr, task.get_predicate_amount());
    }

    int compute_heuristic(HELP::StripsState &state) {
        HELP::ll h = heuristic.compute(state);
        if (std::numeric_limits<HELP::ll>::max() == h) {
            return std::numeric_limits<int>::max();
        }
        return h;
    }

public:
    HeLPHeuristic(const Task &task) : task(task_adapter(task)), heuristic(HeLPHeuristic::task) {
        //TODO: this is a hack to fill in statics in regradd
        compute_heuristic(HeLPHeuristic::task.get_initital_state());
    }

    int compute_heuristic(const DBState &s, const Task &task) override {
        HELP::StripsState s_adapted(state_adapter(s));
        HELP::ll h_val = compute_heuristic(s_adapted);
        if (h_val == std::numeric_limits<HELP::ll>::max()) {
            return UNSOLVABLE_STATE;
        }
        return h_val;
    }
};

#endif  // SEARCH_HELP_HEURISTIC_H
