import time
import tempfile
import pprint
import json
import multiprocessing as mp
import ray
import itertools
import sys
import os
import errno

from os import listdir
from os.path import isfile, join

from generators.generator import (
    Generator,
    SingleParameterGenerator,
    RandomWalkGenerator,
)
from strips_hgn.training_data.save import _TrainingDataEncoder

from strips_hgn.training_data.generate import (
    _generate_optimal_state_value_pairs_for_problem,
)


@ray.remote
def solve_generated_problem(problem):
    """Compute the optimal state-value pairs for one generated problem.

    Runs as a Ray remote task.

    Args:
        problem: A generated problem exposing a ``goals`` collection. It is
            passed unchanged to
            ``_generate_optimal_state_value_pairs_for_problem``.

    Returns:
        The state-value pairs produced for ``problem``, or an empty list when
        the problem has no goals or no pairs were produced.
    """
    # ignore problems where the goal is empty
    if len(problem.goals) == 0:
        return []

    # The temporary file provides a unique base name in the current directory
    # from which the ``sas_name`` and ``plan_name`` arguments are derived.
    # The context manager closes (and thereby deletes) it deterministically,
    # even when solving raises, instead of leaving that to garbage collection.
    with tempfile.NamedTemporaryFile(dir=".") as tf:
        state_value_pairs = _generate_optimal_state_value_pairs_for_problem(
            problem, sas_name=tf.name + ".sas", plan_name=tf.name + ".plan"
        )

    if len(state_value_pairs) == 0:
        return []
    return state_value_pairs

def generate_training_problems(generator, problem_dir, output_dir, num_processors=4):
    """Generate new problems from every PDDL instance found in ``problem_dir``.

    Args:
        generator: Object whose ``call_generator(instance, count, name)``
            method returns the problems generated from ``instance``.
        problem_dir: Directory scanned (non-recursively) for instance files,
            i.e. regular files whose name contains ".pddl" but not "domain".
        output_dir: Directory in which the temporary file names handed to the
            generator are created. It is created if it does not exist yet.
        num_processors: Count passed to ``generator.call_generator`` for each
            instance. Defaults to 4.

    Returns:
        dict mapping each instance path to whatever ``call_generator``
        returned for that instance.

    Raises:
        OSError: If ``problem_dir`` cannot be listed or ``output_dir`` cannot
            be created.
    """
    # Names containing "domain" are filtered out here, so no separate removal
    # of "domain.pddl" is needed afterwards.
    problem_list = [
        join(problem_dir, f)
        for f in listdir(problem_dir)
        if isfile(join(problem_dir, f)) and ".pddl" in f and "domain" not in f
    ]

    os.makedirs(output_dir, exist_ok=True)

    new_problem_dict = dict()
    # for each instance generate new problems
    for instance in problem_list:
        # Hand the generator a unique file name located inside output_dir.
        tf = tempfile.NamedTemporaryFile(dir=output_dir)
        new_problem_dict[instance] = generator.call_generator(
            instance, num_processors, tf.name
        )

    return new_problem_dict

def generate_training_data(generator, fd_planning_time, total_time, problem_dir):
    """Generate problems from ``problem_dir`` and solve them as Ray tasks.

    Problems are generated into the "generated_problems" directory via
    ``generate_training_problems`` and each one is submitted to
    ``solve_generated_problem``. The generated problem files are left in
    place.

    Args:
        generator: Problem generator forwarded to
            ``generate_training_problems``.
        fd_planning_time: Accepted for interface compatibility; currently
            not used.
        total_time: Accepted for interface compatibility; currently not used.
        problem_dir: Directory containing the instances to generate from.

    Returns:
        dict mapping each instance path to a list holding, for every problem
        generated from that instance, the result of
        ``solve_generated_problem``.
    """
    # The generator must be passed as the first positional argument.
    new_problem_dict = generate_training_problems(
        generator, problem_dir, "generated_problems"
    )

    # Only start Ray when it is not already running in this process.
    if not ray.is_initialized():
        ray.init()

    state_values_per_problem = dict()
    for instance, generated_problems in new_problem_dict.items():
        # Submit every generated problem first, then wait for all results.
        pending = [
            solve_generated_problem.remote(problem)
            for problem in generated_problems
        ]
        state_values_per_problem[instance] = ray.get(pending)

    return state_values_per_problem
    

def generate_training_data2(generator, fd_planning_time, total_time, problem_dir):
    """Generate, solve and clean up problems for every instance in a directory.

    For each instance file the generator is asked for ``num_processors``
    problems, which are solved in parallel as Ray tasks. The generated
    problem files are removed once their results have been collected.

    Args:
        generator: Object providing ``call_generator(instance, count, name)``
            whose returned problems expose a ``problem_pddl`` file path.
        fd_planning_time: Currently not used.
        total_time: Currently not used.
        problem_dir: Directory scanned (non-recursively) for instance files,
            i.e. regular files whose name contains ".pddl" but not "domain".

    Returns:
        dict mapping each instance path to the list of results returned by
        ``solve_generated_problem`` for the problems generated from it.
    """
    # Names containing "domain" are filtered out here, so no separate removal
    # of "domain.pddl" is needed afterwards.
    problem_list = [
        join(problem_dir, f)
        for f in listdir(problem_dir)
        if isfile(join(problem_dir, f)) and ".pddl" in f and "domain" not in f
    ]

    num_processors = 4
    # ray.init() fails when Ray is already initialized in this process, so
    # only start it when needed.
    if not ray.is_initialized():
        ray.init()

    state_values_per_problem = dict()
    # for each instance generate num_processors new problems
    for instance in problem_list:
        # Hand the generator a unique file name in the current directory.
        tf = tempfile.NamedTemporaryFile(dir=".")
        generated_problems = generator.call_generator(instance, num_processors, tf.name)
        # Solve at most num_processors problems. Slicing (instead of indexing
        # 0..num_processors-1) avoids an IndexError when the generator
        # returns fewer problems than requested.
        pending = [
            solve_generated_problem.remote(problem)
            for problem in generated_problems[:num_processors]
        ]
        state_values_per_problem[instance] = ray.get(pending)
        for problem in generated_problems:
            os.remove(problem.problem_pddl)

    return state_values_per_problem


def main(domain):
    """Generate training data for ``domain`` and dump it to data_file.json.

    Args:
        domain: Forwarded to ``RandomWalkGenerator`` as its second argument.
    """
    # Time budgets, expressed in seconds.
    seconds_per_minute = 60
    seconds_per_hour = 60 * seconds_per_minute
    planner_budget = 30 * seconds_per_minute
    generation_budget = 18 * seconds_per_hour

    # default lower bound and upper bound is 50 and 200
    walker = RandomWalkGenerator("run_my_generator.sh", domain, 10, 50)

    training_data = generate_training_data(
        walker, planner_budget, generation_budget, "problems"
    )

    # Serialize with the project's encoder so training-data objects are
    # written as JSON.
    with open("data_file.json", "w") as out_file:
        json.dump(training_data, out_file, cls=_TrainingDataEncoder, indent=4)


if __name__ == "__main__":
    # Validate the command line explicitly: an assert is stripped when Python
    # runs with -O, and a missing argument would otherwise surface as a bare
    # IndexError.
    if len(sys.argv) < 2:
        sys.exit("Usage: {} <domain-file>".format(sys.argv[0]))
    domain = sys.argv[1]
    if "domain" not in domain:
        sys.exit("Error: domain file does not contain name domain")
    main(domain)
