import time
import tempfile
import pprint
import json
import multiprocessing as mp
import ray
import itertools
import os
import errno

from generators.generator import Generator, SingleParameterGenerator
from strips_hgn.training_data.save import _TrainingDataEncoder

from strips_hgn.training_data.generate import (
    _generate_optimal_state_value_pairs_for_problem,
)


@ray.remote
def generate_training_data(generator, fd_planning_time, total_time, problem_dir):
    """Generate optimal state-value training pairs until ``total_time`` elapses.

    Runs as a Ray remote task. Problems are generated starting at
    ``generator.max_param``. After each problem that yields training data, the
    parameter is decremented, wrapping back to ``generator.max_param`` once it
    reaches ``generator.lb``. A problem that is skipped (empty goal or no
    state-value pairs) is retried with the same parameter.

    Args:
        generator: Problem generator exposing ``max_param``, ``lb`` and
            ``call_generator(parameter, path)``.
        fd_planning_time: Planning time budget. NOTE(review): not used by this
            function; confirm whether it should be forwarded to the planner
            call.
        total_time: Time budget in seconds, compared against ``time.time()``.
            The deadline is only checked between problems, so the task can run
            past it by the duration of one generation/planning step.
        problem_dir: Directory in which the generated problem files are
            written.

    Returns:
        dict mapping each kept problem file path to its state-value pairs.
    """
    end_time = time.time() + total_time
    state_values_per_problem = {}
    parameter = generator.max_param

    # Iteratively generate problems for different parameters
    while time.time() < end_time:
        # Only a unique file name is needed. Close the handle immediately so a
        # file descriptor is not leaked per iteration. delete=False keeps the
        # file on disk after the handle is closed.
        with tempfile.NamedTemporaryFile(
            dir=problem_dir, prefix="problem_", suffix=".pddl", delete=False
        ) as tf:
            problem_path = tf.name

        problem = generator.call_generator(parameter, problem_path)

        # Ignore problems where the goal is empty. The same parameter is
        # retried on the next iteration.
        if len(problem.goals) == 0:
            try:
                os.remove(problem_path)  # orphan file: never returned
            except FileNotFoundError:
                pass
            continue

        state_value_pairs = _generate_optimal_state_value_pairs_for_problem(
            problem,
            sas_name=problem_path + ".sas",
            plan_name=problem_path + ".plan",
        )
        if len(state_value_pairs) == 0:
            try:
                os.remove(problem_path)  # orphan file: never returned
            except FileNotFoundError:
                pass
            continue

        state_values_per_problem[problem_path] = state_value_pairs
        # Step down through the parameter range, wrapping back to the maximum.
        parameter = generator.max_param if parameter == generator.lb else parameter - 1

    return state_values_per_problem


def blocksworld_generator(
    num_processors=4, problem_dir="problems", output_path="data_file.json"
):
    """Generate Blocksworld training data in parallel and write it to JSON.

    First initializes the generator's parameter space. It then launches
    ``num_processors`` Ray tasks of ``generate_training_data``, each with the
    same generation time budget. The per-task results and the initialization
    data are dumped, as one JSON list, to ``output_path``.

    Args:
        num_processors: Number of Ray tasks to launch (default 4).
        problem_dir: Directory for generated problem files; created if missing.
        output_path: Path of the JSON file the collected data is written to.
    """
    # exist_ok replaces the manual errno.EEXIST check. Other OS errors still
    # propagate.
    os.makedirs(problem_dir, exist_ok=True)

    generator = SingleParameterGenerator(
        domain_dir="4ops/domain.pddl",
        generator_call_string="python gen_blocks.py",
        lower_bound=3,
        upper_bound=18,
    )
    one_minute = 60
    one_hour = 60 * one_minute
    fd_planning_time = 30 * one_minute
    initialize_time = 2 * one_hour
    generation_time = 18 * one_hour

    # Initialize parameter space and store trajectories computed during initialization
    initial_data = generator.initialize_parameter_space(
        fd_planning_time, initialize_time
    )

    ray.init()
    try:
        tasks = [
            generate_training_data.remote(
                generator, fd_planning_time, generation_time, problem_dir
            )
            for _ in range(num_processors)
        ]
        data = ray.get(tasks)
    finally:
        # Release the Ray runtime even if a task fails.
        ray.shutdown()

    data.append(initial_data)

    with open(output_path, "w", encoding="utf-8") as write_file:
        json.dump(
            data, write_file, indent=4, cls=_TrainingDataEncoder,
        )


# Script entry point: run the Blocksworld data-generation pipeline with its
# default settings.
if __name__ == "__main__":
    blocksworld_generator()
