#ifndef FAST_BACKWARD_PLANNING_TASK_H
#define FAST_BACKWARD_PLANNING_TASK_H

#include <limits.h>

#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../utils/type_defs.h"
#include "../conj_query_eval/join_order_generator/join_order_generator.h"
#include "../conj_query_eval/join_realizer/join_order_graph.h"
#include "db_propositional_shared.h"

namespace HELP { 

/// A STRIPS state: a GroundAtomCol under its own type name.
/// Adds no members; every GroundAtomCol constructor is inherited.
struct StripsState : GroundAtomCol {
    using GroundAtomCol::GroundAtomCol;
};

typedef std::vector<Atom> StripsGoal; //TODO: this should be ground atom col

inline void remove_nullary(StripsGoal &goal, ll true_obj_id) {
    for (auto &sg : goal) {
        sg.make_unary_const(true_obj_id);
    }
}

/// Name of an action parameter / variable (e.g. "?x").
using Parameter = std::string;

// TODO: where to move this?
/// Returns the id of the artificial predicate standing for constant `const_id`,
/// allocating a new predicate id on first use.
///
/// @param const_id               object id of the constant
/// @param constant_to_const_pred cache constant id -> predicate id; extended in place
/// @param p_id_count             next free predicate id; incremented only when a
///                               new predicate is allocated
/// @return the (possibly newly allocated) predicate id for `const_id`
inline ll create_const_pred_if_needed(ll const_id, std::unordered_map<ll,ll> &constant_to_const_pred, ll &p_id_count) {
    // One hash lookup instead of contains + emplace + at: try_emplace inserts
    // only if the key is absent, and only then do we consume a predicate id.
    const auto [it, inserted] = constant_to_const_pred.try_emplace(const_id, p_id_count);
    if (inserted) {
        ++p_id_count;
    }
    return it->second;
}

/// Lifted action schema: a name, its variable names, precondition / add /
/// delete atom lists and an integer cost.
///
/// Variable arguments of the atoms refer to positions in `var_names`: the
/// helpers below that introduce a variable append a name and use its index.
/// The normalisation helpers (make_atoms_unary, remove_constants_from_atoms,
/// remove_equalities) rewrite the atom lists in place and may append
/// artificial variables and preconditions.
class Action {
    std::string name;
    std::vector<std::string> var_names; //TODO: par
    std::vector<Atom> pre;
    std::vector<Atom> add;
    std::vector<Atom> del;
    ll cost;

    /// Appends a fresh artificial variable "?artf_var<i>" to var_names and
    /// returns its index i. Shared by create_const_var and eq_remove.
    ll add_artificial_var() {
        const ll var_id = static_cast<ll>(var_names.size());
        var_names.push_back("?artf_var" + std::to_string(var_id));
        return var_id;
    }

public:
    /// Takes every component by value and moves it into the member.
    Action( std::string name,
            std::vector<std::string> var_names, //TODO: par
            std::vector<Atom> pre,
            std::vector<Atom> add,
            std::vector<Atom> del,
            ll cost)
            : name(std::move(name)),
              var_names(std::move(var_names)),
              pre(std::move(pre)),
              add(std::move(add)),
              del(std::move(del)),
              cost(cost)
            {}

    /// Number of variables of the schema (including artificial ones added so far).
    ll get_arity() const {
        return static_cast<ll>(var_names.size());
    }

    ll get_cost() const {
        return cost;
    }

    auto &get_name() const {
        return name;
    }

    auto &get_var_names() const {
        return var_names;
    }

    const std::vector<Atom> &get_pre() const {
        return pre;
    }

    /// Replaces the precondition list with a copy of `new_pre`.
    void set_pre(const std::vector<Atom> &new_pre) {
        pre = new_pre;
    }

    const std::vector<Atom> &get_add() const {
        return add;
    }

    const std::vector<Atom> &get_del() const {
        return del;
    }

    /// Name of variable `id`; throws std::out_of_range for an invalid id.
    const std::string &get_var_name(ll id) const {
        return var_names.at(id);
    }

    /// True if any precondition, add or delete atom is nullary.
    bool nullary_atom_exists() { //TODO: maybe want to combine functions just doing something for all atoms
        auto is_nullary = [](Atom &atom) { return atom.is_nullary(); };
        return std::any_of(pre.begin(), pre.end(), is_nullary)
            || std::any_of(add.begin(), add.end(), is_nullary)
            || std::any_of(del.begin(), del.end(), is_nullary);
    }

    /// If any atom is nullary: appends the variable "?_t_par", adds an
    /// argument-less precondition over predicate `just_t_id`, and calls
    /// make_unary(t_par) on every pre/add/del atom. Otherwise does nothing.
    /// NOTE(review): presumably Atom::make_unary only extends argument-less
    /// atoms, which is how the new precondition receives its parameter.
    void make_atoms_unary(ll just_t_id) {
        if (!nullary_atom_exists()) {
            return;
        }

        const ll t_par = static_cast<ll>(var_names.size());
        var_names.push_back("?_t_par");

        pre.push_back({false, just_t_id, {}}); // par will be added by following

        for (auto &atom: pre) {
            atom.make_unary(t_par);
        }
        for (auto &atom: add) {
            atom.make_unary(t_par);
        }
        for (auto &atom: del) {
            atom.make_unary(t_par);
        }
    }

    // optional implementation was available at b0c722fa848ed45a4c302276756bf369e5a42017
    /// Creates a fresh artificial variable for constant `const_id`, records the
    /// pair (const_id, variable index) in `const_to_var`, and returns the index.
    ll create_const_var(ll const_id, std::vector<std::pair<ll,ll>> &const_to_var) {
        const ll var_id = add_artificial_var();
        const_to_var.emplace_back(const_id, var_id);
        return var_id;
    }

    /// Replaces every constant argument of the atoms in `li` by a fresh
    /// artificial variable (one per occurrence), recording each
    /// (constant id, variable index) pair in `const_to_var`.
    /// `constant_to_const_pred` is not used here; the constant predicates are
    /// created by remove_constants_from_atoms.
    void const_remove(std::vector<Atom> &li, std::unordered_map<ll, ll> &constant_to_const_pred, std::vector<std::pair<ll,ll>> &const_to_var) {
        std::vector<Atom> new_li;
        new_li.reserve(li.size());

        for (auto &atom : li) {
            std::vector<ParameterOrObject> new_args;
            for (auto &arg : atom.get_args()) {
                if (!arg.is_variable()) {
                    auto v_num = create_const_var(arg.get_index(), const_to_var);
                    new_args.emplace_back(true, v_num);
                } else {
                    new_args.push_back(arg);
                }
            }
            new_li.emplace_back(atom.is_negated(), atom.get_predicate(), new_args);
        }

        li = std::move(new_li);
    }

    /// Makes the arguments of every atom in `li` pairwise distinct. Each
    /// repeated occurrence of a variable inside one atom is replaced by a fresh
    /// artificial variable, and the pair (original variable, fresh variable) is
    /// appended to `equalities` so the caller can re-impose the equality as a
    /// precondition. All arguments must be variables (asserted).
    void eq_remove(std::vector<Atom> &li, std::vector<std::pair<ll,ll>> &equalities) {
        std::vector<Atom> new_li;
        new_li.reserve(li.size());
        for (auto &atom : li) {
            std::vector<ParameterOrObject> new_args;
            std::unordered_set<ll> seen_vars;
            for (auto &arg : atom.get_args()) {
                assert(arg.is_variable());
                auto var = arg.get_index();
                if (seen_vars.insert(var).second) {
                    // First occurrence in this atom: keep it.
                    new_args.push_back(arg);
                } else {
                    // Repeated occurrence: substitute a fresh variable and remember
                    // that it must equal `var` (previously this pair was dropped,
                    // so the equality constraint was silently lost).
                    const ll fresh_var = add_artificial_var();
                    new_args.emplace_back(true, fresh_var);
                    equalities.emplace_back(var, fresh_var);
                }
            }
            new_li.emplace_back(atom.is_negated(), atom.get_predicate(), new_args);
        }
        li = std::move(new_li);
    }

    /// Removes repeated variables within the pre/add/del atoms and, for each
    /// substitution, adds the precondition `eq_p_id(original, fresh)`.
    void remove_equalities(ll eq_p_id) {
        std::vector<std::pair<ll,ll>> equalities;

        eq_remove(pre, equalities);
        eq_remove(add, equalities);
        eq_remove(del, equalities);

        for (auto &[v1,v2] : equalities) {
            pre.emplace_back(eq_p_id, std::vector<ParameterOrObject>{{true, v1}, {true, v2}});
        }
    }

    /// Replaces constants in the pre/add/del atoms by fresh variables and, for
    /// each one, adds a unary precondition over the constant's predicate
    /// (obtained via create_const_pred_if_needed, which may advance `p_id_count`).
    void remove_constants_from_atoms(std::unordered_map<ll, ll> &constant_to_const_pred, ll &p_id_count) {
        std::vector<std::pair<ll,ll>> const_to_var;

        const_remove(pre, constant_to_const_pred, const_to_var);
        const_remove(add, constant_to_const_pred, const_to_var);
        const_remove(del, constant_to_const_pred, const_to_var);

        for (auto &[const_id, var] : const_to_var) {
            auto c_predicate = create_const_pred_if_needed(const_id, constant_to_const_pred, p_id_count);
            pre.emplace_back(c_predicate, std::vector<ParameterOrObject>{{true, var}});
        }
    }

    Parameters get_eff_pars(); //TODO: void and write to parameter
};

/// Two atoms of one join-order step, bundled for LiftedStripsTask::dump_dl_repr.
/// Non-owning: both referenced atoms must outlive this object.
struct PrintableJoinOrderEl { //TODO move, could auto generate this once parref and are properly wrapped
    const Atom& from;
    const Atom& to;
};

/// A PrintableJoinOrderEl plus two lists of non-owning Parameter pointers
/// (presumably the join parameters and the tracked parameters of that step).
struct PrintableAnnotatedJoinOrderEl { //TODO move, could auto generate this once parref and are properly wrapped
    PrintableJoinOrderEl join_order_el;
    std::vector<const Parameter*> join_pars;    //TODO const
    std::vector<const Parameter*> tracked_pars; //TODO const
};


typedef std::vector<Predicate> Predicates;
typedef std::vector<Action> Actions;
typedef std::vector<Object> Objects;

//TODO: get rid of result pars return
inline std::tuple<QueryEval::JoinOrder, QueryEval::Query, Parameters> create_join_order(Action &action, DBInfo &info); //TODO: move decl.
inline std::tuple<QueryEval::AnnotatedJoinOrder, QueryEval::Query, Parameters> create_annoted_join_order(Action &action, DBInfo &info);
inline std::tuple<std::vector<QueryEval::JoinOrderGraph::ExtNodeRepr>, QueryEval::Query, Parameters> create_join_graph(Action &action, DBInfo &info);

/// Variable number passed to LiftedStripsTask::dump_dl_repr(const NumericalVar&).
/// Value-initialised so a default-constructed instance never holds an
/// indeterminate value; still an aggregate, so NumericalVar{n} keeps working.
struct NumericalVar { //TODO: make this a standard ll wrapper
    ll var_num{};
};

//TODO make standard edge
/// One variable/object pair, both stored by name
/// (presumably: variable `var` is bound to object `obj`).
struct VarMapping
{
    std::string var;
    std::string obj; //TODO: should be object ref instead string
};

/// Ordered list of VarMapping entries. //TODO: make this a "bracewrapped class"
struct VarMap
{
    std::vector<VarMapping> var_mapping;
};

/// Result of LiftedStripsTask::normalize().
/// The field is value-initialised so a default-constructed result is never
/// indeterminate; the struct remains an aggregate.
struct NormalizeReturn {
    ll just_t_id{}; // NOTE(review): presumably the predicate id later passed to Action::make_atoms_unary — confirm
};

/// A lifted STRIPS task: predicates, objects, action schemas, an initial state
/// and a goal, together with the normalisation passes and printers
/// (dump_dl_repr / print_*) that operate on it. Most members are defined out
/// of line.
class LiftedStripsTask {
    Predicates predicates;
    Objects objects;
    Actions actions;

    StripsState initial_state;
    StripsGoal goal;

    // internal //TODO: should be moved to some pretty printer class
    // Action currently being printed. Defaults to nullptr so it is never an
    // indeterminate pointer before the printing code sets it.
    Action *current_print_action = nullptr; //TODO: could also be a parameter that we pass down while printing
    bool var_as_num = false; // TODO: rm (hacky)
public:
    LiftedStripsTask() = default;

    /// Builds a task from its components; each argument is taken by value and
    /// moved into the corresponding member.
    LiftedStripsTask(
            std::vector<Predicate> predicates,
            std::vector<Object> objects,
            std::vector<Action> actions,
            StripsState initial_state,
            StripsGoal goal
        ) : predicates(std::move(predicates)),
            objects(std::move(objects)),
            actions(std::move(actions)),
            initial_state(std::move(initial_state)),
            goal(std::move(goal))
        {}

    //TODO: below should not be public
    //TODO: seems like printing is getting out of hand should be declared somewhere else
    //TODO: think there could be some more const
    void dump_dl_repr(const ObjectRef &obj, std::ostream &outs);
    void dump_dl_repr(Actions &actions, std::ostream &outs); //TODO: actually doesn't need parameter
    void dump_dl_repr(Action &action, std::ostream &outs); //TODO: actually doesn't need parameter
    void dump_dl_repr(const GroundAtom &atom, std::ostream &outs);
    void dump_dl_repr(const Atom &atom, std::ostream &outs);
    void dump_dl_repr(StripsState &state, std::ostream &outs);
    void dump_dl_repr(const ParameterOrObject &arg, std::ostream &outs);
    void dump_dl_repr(const PrintableJoinOrderEl &el, std::ostream &outs);
    void dump_dl_repr(const PrintableAnnotatedJoinOrderEl &el, std::ostream &outs);
    void dump_dl_repr(const Parameter &el, std::ostream &outs);

    void dump_dl_repr(const QueryEval::JoinOrderGraph::ExtNodeRepr &el, std::ostream &outs);
    void dump_dl_repr(const QueryEval::NodeId &el, std::ostream &outs);
    void dump_dl_repr(const QueryEval::MergeNode &el, std::ostream &outs);
    void dump_dl_repr(const QueryEval::ReOrderNode &el, std::ostream &outs);
    void dump_dl_repr(const NumericalVar &var, std::ostream &outs);
    void dump_dl_repr(const QueryEval::PositionForward &position_map, std::ostream &outs);
    void dump_dl_repr(const QueryEval::PositionMerge &position_map, std::ostream &outs);
    void dump_dl_repr(const QueryEval::JoinGraphNode &node, std::ostream &outs);
    void dump_dl_repr(const VarMap &var_map, std::ostream &outs);
    void dump_dl_repr(const VarMapping &var_maping, std::ostream &outs);
    void dump_dl_repr(const QueryEval::InitCollection &init_col, std::ostream &outs);

    template<typename T>
    void dump_dl_repr(const std::vector<T> &v, std::ostream &outs);
    /// Pointer overload: prints the pointee.
    template<typename T>
    void dump_dl_repr(const T *el, std::ostream &outs) {
        dump_dl_repr(*el, outs);
    }

    void query_eval_print(std::ostream &outs, bool show_tables);

    template<typename JoinOrderEl, std::tuple<std::vector<JoinOrderEl>, QueryEval::Query, Parameters> create_join_order(Action &action, DBInfo &info)>
    void _print_join_orders(std::ostream &outs);

    //TODO: to proper signature
    void query_eval_print(QueryEval::JoinOrderGraph &graph, QueryEval::TableCache &cache, Action *action, std::ostream &outs, bool show_tables);

    ll get_max_var_amount();
    ll get_max_predicate_arity();

    const std::string &get_pred_name(ll id);
    const std::string &get_obj_name(ll id);
    const std::string &get_var_name(ll id);

    ll remove_nullary_predicates_and_empty_preconditions();

    std::set<ll> compute_static_predicates();
    std::map<ll, ll> compute_duplicated_objects();
    std::map<ll, ll> compute_init_table_size();

public:
    void print_stats(std::ostream &outs);
    void print_predicates(std::ostream &outs);
    void print_initial_ground_actions(std::ostream &outs) { query_eval_print(outs, false); }
    void print_table_cache(std::ostream &outs) { query_eval_print(outs, true); }

    void print_initial_add_value(std::ostream &outs);

    void print_normal_join_orders(std::ostream &outs) { _print_join_orders<QueryEval::JoinOrderElement, create_join_order>(outs); } //TODO: make this one template arg
    void print_annotated_join_orders(std::ostream &outs) { _print_join_orders<QueryEval::AnnotatedJoinOrderElement, create_annoted_join_order>(outs); } //TODO: make this one template arg
    //TODO: wrap pair type below
    void print_join_order_graph(std::ostream &outs) { _print_join_orders<QueryEval::JoinOrderGraph::ExtNodeRepr, create_join_graph>(outs); } //TODO: make this one template arg
    void combined_query_eval_print(std::ostream &outs);

    void one_goal_only();
    void dump_dl_repr(std::ostream &outs);
    NormalizeReturn normalize();
    void remove_constants_from_atoms();
    void remove_equal_parameter();

    auto &get_actions() {
        return actions;
    }

    auto &get_goal() {
        return goal;
    }

    auto get_predicate_amount() {
        return predicates.size();
    }

    /// Kept under its historical (misspelled) name for existing callers.
    auto &get_initital_state() {
        return initial_state;
    }

    /// Correctly spelled accessor for the initial state.
    auto &get_initial_state() {
        return initial_state;
    }

    DBInfo get_info() { //TODO important: TODO important -- do not recompute every time
        return {
            static_cast<ll>(objects.size()),
            static_cast<ll>(predicates.size()),
            compute_static_predicates(),
            compute_duplicated_objects(),
            compute_init_table_size()
        };
    }

    void simple_datalog_transformation();
};

}

#include "lifted_strips_task.tpp" //TODO: doesn't seem like a good solution to mix this with .cpp

#endif