#include "sampling_engine.h"

#include "../evaluation_context.h"
#include "../heuristic.h"

#include "../algorithms/ordered_set.h"
#include "../task_utils/successor_generator.h"
#include "../task_utils/task_properties.h"
#include "../task_utils/predecessor_generator.h"

#include <cassert>
#include <fstream>
#include <iostream>
#include <set>
#include <memory>
#include <numeric>
#include <stdio.h>
#include <string>


using namespace std;

namespace sampling_engine {
/* Line written ahead of every batch of samples that
   SamplingEngine::save_plan_if_necessary dumps to a sample file. */
const string SAMPLE_FILE_MAGIC_WORD = "# MAGIC FIRST LINE";
/* Global storage in which sampling algorithms can record arbitrary paths,
   i.e. plans (sequences of operator ids) and trajectories (sequences of
   state ids).
   NOTE: everybody can access this variable, so be certain who reads and
   writes it and how. */
std::vector<Path> paths;

/* Creates a path and appends 'start' to its trajectory. */
Path::Path(StateID start) {
    this->trajectory.push_back(start);
}

/* Nothing to release explicitly. */
Path::~Path() = default;

/* Extends the path by one step: 'op' is appended to the plan and 'next' to
   the trajectory. */
void Path::add(OperatorID op, StateID next) {
    this->plan.push_back(op);
    this->trajectory.push_back(next);
}

/* Read-only access to the operator ids stored in this path. */
const Plan &Path::get_plan() const {
    return this->plan;
}

/* Read-only access to the state ids stored in this path. */
const Trajectory &Path::get_trajectory() const {
    return this->trajectory;
}


/* Returns the given list of sampling techniques. If the list is empty, a
   single TechniqueNull is inserted first, so the result is never empty. */
static vector<shared_ptr<sampling_technique::SamplingTechnique>>
prepare_sampling_techniques(
    vector<shared_ptr<sampling_technique::SamplingTechnique>> input) {
    if (!input.empty()) {
        return input;
    }
    input.push_back(make_shared<sampling_technique::TechniqueNull>());
    return input;
}


/* Maps the option string "csv" or "fields" to the matching SampleFormat.
   Any other string is reported on stderr and terminates the program with
   SEARCH_INPUT_ERROR. */
SampleFormat select_sample_format(const string &sample_format) {
    const bool is_csv = (sample_format == "csv");
    const bool is_fields = (sample_format == "fields");
    if (!is_csv && !is_fields) {
        cerr << "Invalid sample format:" << sample_format << endl;
        utils::exit_with(utils::ExitCode::SEARCH_INPUT_ERROR);
    }
    return is_csv ? SampleFormat::CSV : SampleFormat::FIELDS;
}

/* Maps the option string "pddl" or "fdr" to the matching StateFormat.
   Any other string is reported on stderr and terminates the program with
   SEARCH_INPUT_ERROR. */
StateFormat select_state_format(const string &state_format) {
    const bool is_pddl = (state_format == "pddl");
    const bool is_fdr = (state_format == "fdr");
    if (!is_pddl && !is_fdr) {
        cerr << "Invalid state format:" << state_format << endl;
        utils::exit_with(utils::ExitCode::SEARCH_INPUT_ERROR);
    }
    return is_pddl ? StateFormat::PDDL : StateFormat::FDR;
}


/* Builds the engine from the parsed options (see add_sampling_options for
   the option names) and validates the numeric limits.

   Exits with SEARCH_INPUT_ERROR if sample_cache_size is not positive or if
   max_sample_files is neither positive nor -1 (meaning unlimited). The
   format strings are validated by select_sample_format and
   select_state_format.

   NOTE(review): sample_cache_size, generated_samples and finalize are used
   by other methods but are not initialized here; presumably they have
   default member initializers in the header -- confirm. */
SamplingEngine::SamplingEngine(const Options &opts)
    : SearchEngine(opts),
      shuffle_sampling_techniques(opts.get<bool>("shuffle_techniques")),
      skip_undefined_facts(opts.get<bool>("skip_undefined_facts")),
      max_sample_cache_size(opts.get<int>("sample_cache_size")),
      iterate_sample_files(opts.get<bool>("iterate_sample_files")),
      index_sample_files(opts.get<int>("index_sample_files")),
      max_sample_files(opts.get<int>("max_sample_files")),
      count_sample_files(0),
      // An empty technique list is replaced by a single TechniqueNull.
      sampling_techniques(
        prepare_sampling_techniques(
            opts.get_list<shared_ptr<sampling_technique::SamplingTechnique>>(
                "techniques"))),
      current_technique(sampling_techniques.begin()),
      sample_format(select_sample_format(opts.get<string>("sample_format"))),
      state_format(select_state_format(opts.get<string>("state_format"))),
      field_separator(opts.get<string>("field_separator")),
      state_separator(opts.get<string>("state_separator")),
      rng(utils::parse_rng_from_options(opts)) {
    const bool cache_size_is_positive = (max_sample_cache_size > 0);
    if (!cache_size_is_positive) {
        cerr << "sample_cache_size has to be positive: "
             << max_sample_cache_size << endl;
        utils::exit_with(utils::ExitCode::SEARCH_INPUT_ERROR);
    }

    const bool unlimited_files = (max_sample_files == -1);
    const bool file_limit_is_positive = (max_sample_files > 0);
    if (!(unlimited_files || file_limit_is_positive)) {
        cerr << "max_sample_files has to be positive or -1 for unlimited: "
             << max_sample_files << endl;
        utils::exit_with(utils::ExitCode::SEARCH_INPUT_ERROR);
    }
}

void SamplingEngine::initialize() {
    cout << "Initializing Sampling Engine...";
    remove(plan_manager.get_plan_filename().c_str());
    cout << "done." << endl;
}

void SamplingEngine::update_current_technique() {
    if (shuffle_sampling_techniques) {
        current_technique = sampling_techniques.end();
        int total_remaining = accumulate(
                sampling_techniques.begin(),
                sampling_techniques.end(),
                0,
                [] (
                        int sum,
                        const shared_ptr<sampling_technique::SamplingTechnique> &st) {
                    assert (st->get_count() >= st->get_counter());
                    return sum + (st->get_count() - st->get_counter());});
        if (total_remaining == 0) {
            return;
        }
        int chosen = (*rng)(total_remaining);
        for (auto iter = sampling_techniques.begin(); iter != sampling_techniques.end(); ++iter){
            int remaining = (*iter)->get_count() - (*iter)->get_counter();
            if (chosen < remaining) {
                current_technique = iter;
                break;
            } else {
                chosen -= remaining;
            }
        }
        assert (current_technique != sampling_techniques.end());
    } else {
        while(current_technique != sampling_techniques.end() &&
              (*current_technique)->empty()) {
            current_technique++;
        }
    }
}

/* Performs one sampling step.

   Picks the technique to use, lets it produce the next task, samples that
   task and stores the resulting batch in the sample cache. Once the cache
   holds more than max_sample_cache_size samples, it is written to disk.

   Returns SOLVED (after flushing all cached samples) when no technique has
   work left, IN_PROGRESS otherwise. */
SearchStatus SamplingEngine::step() {
    update_current_technique();
    if (current_technique == sampling_techniques.end()) {
        // Nothing left to sample: write out whatever is still cached.
        finalize = true;
        save_plan_if_necessary();
        return SOLVED;
    }

    const shared_ptr<AbstractTask> next_task = (*current_technique)->next(task);
    // Push the freshly produced batch directly so it is moved into the cache
    // instead of being deep-copied from a named local.
    sample_cache.push_back(sample(next_task));
    sample_cache_size += sample_cache.back().size();
    if (sample_cache_size > max_sample_cache_size) {
        save_plan_if_necessary();
    }
    return IN_PROGRESS;
}

/* Builds the comment header that describes the layout of a sample file.

   Every line starts with '#'. The header names the sample format, explains
   how the fields are laid out, names the state format and states the
   separator used inside a state entry. The returned text has no trailing
   newline.

   Exits with SEARCH_CRITICAL_ERROR for a sample format this function does
   not know. */
string SamplingEngine::sample_file_header() const {
    ostringstream oss;
    oss << "# SampleFormat: " << sample_format << '\n';
    if (sample_format == SampleFormat::FIELDS) {
        oss << "# Starts with json format describing the fields, followed by \n"
               "# whitespace followed by all fields separated by "
            << field_separator << ".\n";
    } else if (sample_format == SampleFormat::CSV) {
        oss << "# All fields one after another concatenated by "
            << field_separator << ".\n";
    } else {
        utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
    }
    // The remark used to contain a duplicated "and" and an unclosed
    // parenthesis; both are fixed here.
    oss << "# StateFormat: " << state_format << "\n"
        << "# Element in state entry are separated by " << state_separator
        << "(this might have been converted between writing the samples and "
           "long time storage of the samples).";
    return oss.str();
}

/* Appends a textual encoding of 'state' to 'oss' in the configured state
   format.

   PDDL: delegated to task_properties::dump_pddl.
   FDR:  one entry per fact (variable/value pair) of the state's task, in
         variable order and then value order. The entry tells whether the
         state assigns that value to the variable. Entries are joined by
         state_separator. With skip_undefined_facts, facts for which the task
         reports is_undefined are left out.

   Any other state format exits with SEARCH_CRITICAL_ERROR. */
void SamplingEngine::convert_and_push_state(
        ostringstream &oss,
        const State &state) const {
    if (state_format == StateFormat::PDDL) {
        task_properties::dump_pddl(state, oss,
                                   state_separator, skip_undefined_facts);
        return;
    }
    if (state_format != StateFormat::FDR) {
        utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
    }

    const AbstractTask *state_task = state.get_task().get_task();
    bool needs_separator = false;
    for (int var = 0; var < state_task->get_num_variables(); ++var) {
        for (int val = 0; val < state_task->get_variable_domain_size(var);
             ++val) {
            const bool skip_fact = skip_undefined_facts &&
                state_task->is_undefined(FactPair(var, val));
            if (skip_fact) {
                continue;
            }
            // No separator in front of the very first entry.
            if (needs_separator) {
                oss << state_separator;
            }
            oss << (state.get_values()[var] == val);
            needs_separator = true;
        }
    }
}

/* Appends the goal of 'task' to 'oss' as a 0/1 vector with one entry per
   fact, laid out like the FDR state encoding: variables in order, values in
   order, entries joined by state_separator. An entry is 1 iff the fact is a
   goal fact. With skip_undefined_facts, facts for which the task reports
   is_undefined get no entry and undefined goal facts are ignored.

   Only the FDR state format is handled; any other state format exits with
   SEARCH_CRITICAL_ERROR. */
void SamplingEngine::convert_and_push_goal(
        ostringstream &oss, const AbstractTask &task) const {
    if (state_format != StateFormat::FDR) {
        utils::exit_with(utils::ExitCode::SEARCH_CRITICAL_ERROR);
    }

    // Number of facts (var, 0) .. (var, bound - 1) that are left out of the
    // encoding. Always 0 when undefined facts are not skipped.
    auto count_skipped_below = [this, &task](int var, int bound) {
        int skipped = 0;
        if (skip_undefined_facts) {
            for (int val = 0; val < bound; ++val) {
                if (task.is_undefined(FactPair(var, val))) {
                    ++skipped;
                }
            }
        }
        return skipped;
    };

    // shift[var] is the index of the first entry of variable 'var';
    // shift.back() is the total number of entries. Reserve all
    // num_variables + 1 slots up front so the vector never reallocates.
    const int num_variables = task.get_num_variables();
    vector<int> shift;
    shift.reserve(num_variables + 1);
    shift.push_back(0);
    for (int var = 0; var < num_variables; ++var) {
        const int domain_size = task.get_variable_domain_size(var);
        shift.push_back(shift[var] + domain_size -
                        count_skipped_below(var, domain_size));
    }

    vector<int> goal(shift.back(), 0);
    for (int i = 0; i < task.get_num_goals(); ++i) {
        const FactPair &fp = task.get_goal_fact(i);
        if (skip_undefined_facts && task.is_undefined(fp)) {
            continue;
        }
        // Skipped facts with a smaller value move this fact's entry forward.
        goal[shift[fp.var] + fp.value -
             count_skipped_below(fp.var, fp.value)] = 1;
    }

    bool first_round = true;
    for (int i : goal) {
        oss << (first_round ? "" : state_separator) << i;
        first_round = false;
    }
}

/* Prints the number of generated samples (generated_samples plus the samples
   currently held in the cache) and, for every sampling technique, its name
   and its counter relative to its count.

   NOTE(review): generated_samples is not modified anywhere in the visible
   part of this file; confirm that it is updated when the cache is written
   to disk. */
void SamplingEngine::print_statistics() const {
    const auto total_generated = generated_samples + sample_cache_size;
    cout << "Generated Samples: " << total_generated << endl;
    cout << "Sampling Techniques used:" << endl;
    for (const auto &technique : sampling_techniques) {
        cout << '\t' << technique->get_name()
             << ":\t" << technique->get_counter()
             << "/" << technique->get_count() << '\n';
    }
}

void SamplingEngine::save_plan_if_necessary() {
    assert(finalize || sample_cache_size >= max_sample_cache_size);
    while ((finalize || sample_cache_size >= max_sample_cache_size) &&
           sample_cache_size > 0 &&
           (max_sample_files == -1 || count_sample_files < max_sample_files)) {
        ofstream outfile(
            plan_manager.get_plan_filename() + 
            ((iterate_sample_files) ? to_string(index_sample_files++) : ""),
            iterate_sample_files ? ios::trunc : ios::app);
        outfile << SAMPLE_FILE_MAGIC_WORD << endl 
                << sample_file_header() << endl;
        size_t nb_samples = 0;
        size_t idx_vector = 0;
        size_t idx_sample = 0;
        for (const vector<string> & samples: sample_cache) {
            idx_sample = 0;
            for (const string & sample : samples){
                outfile << sample << endl;
                if (++nb_samples >= max_sample_cache_size) {
                    break;
                }
                idx_sample++;
            }
            if (nb_samples >= max_sample_cache_size){
                break;
            }
            idx_vector++;
        }
        outfile.close();
        assert(finalize || nb_samples == max_sample_cache_size);
        if (finalize && nb_samples < max_sample_cache_size) {
            idx_vector--;
            idx_sample--;
        }
        //Erase saved samples
        sample_cache[idx_vector].erase(
            sample_cache[idx_vector].begin(),
            sample_cache[idx_vector].begin() + idx_sample + 1);
        sample_cache.erase(
            sample_cache.begin(),
            sample_cache.begin() + idx_vector + 
                ((sample_cache[idx_vector].size() == 0) ? 1 : 0));
        sample_cache_size -= nb_samples;
        count_sample_files++;
    }
}

/* Registers all command line options read by the SamplingEngine constructor.

   The default_* arguments provide the default values of the options
   sample_format, state_format, field_separator and state_separator. The
   options for the random number generator are added as well.

   Help texts that were built from adjacent string literals now have the
   missing spaces between the fragments, and they list every value accepted
   by select_sample_format and select_state_format. */
void SamplingEngine::add_sampling_options(
        OptionParser &parser,
        const string &default_sample_format,
        const string &default_state_format,
        const std::string &default_field_separator,
        const std::string &default_state_separator) {
    parser.add_list_option<shared_ptr < 
        sampling_technique::SamplingTechnique >> (
        "techniques",
        "List of sampling technique definitions to use",
        "[]");
    parser.add_option<bool>("shuffle_techniques",
            "Instead of using the sampling techniques one after another, "
            "a sampling technique is chosen randomly for each task to "
            "generate, with probability proportional to the number of "
            "remaining tasks of that technique.",
            "false");

    parser.add_option<int> (
        "sample_cache_size",
        "If more than sample_cache_size samples are cached, then the entries "
         "are written to disk and the cache is emptied. When sampling "
         "finishes, all remaining cached samples are written to disk. If "
         "running out of memory, the current cache is lost.", 
        "5000");
    parser.add_option<bool>(
        "iterate_sample_files",
        "Every time the cache is emptied, the output file name is the "
        "specified file name plus an appended increasing index. If not set, "
        "then every time the cache is full, the samples will be appended to "
        "the output file.",
        "false");
    parser.add_option<int>(
        "index_sample_files",
        "Initial index to append to the sample files written. Works only in "
        "combination with iterate_sample_files",
        "0");
    parser.add_option<int>(
        "max_sample_files",
        "Maximum number of sample files which will be written to disk."
        " After writing that many files, the search will terminate. Use -1 "
        "for unlimited.",
        "-1"
            );

    parser.add_option<string> (
        "sample_format",
        "Format in which to write down the samples. The field order is "
        "V - value, current state, optionally goal condition. Choose from: "
        "csv: writes the data fields separated by the field separator; "
        "fields: writes a json description of the fields followed by the "
        "fields separated by the field separator.",
        default_sample_format
    );

    parser.add_option<string> (
        "state_format",
        "Format in which to write down states. Choose from: "
        "fdr: write states as fdr representation of Fast Downward, "
        "all values are separated by the state separator; "
        "pddl: write states in PDDL format.",
        default_state_format
    );

    parser.add_option<string> (
        "field_separator",
        "String sequence used to separate the fields in a sample",
        default_field_separator
    );

    parser.add_option<string> (
        "state_separator",
        "String sequence used to separate the items in a state field",
        default_state_separator
    );
    parser.add_option<bool> (
        "skip_undefined_facts",
        "Does not write down facts representing the variable is undefined",
        "false"
    );
    utils::add_rng_options(parser);
}
}
