#!/usr/bin/env python3

import tempfile
import random
import time
import pprint
import multiprocessing
import sys
import json
import os
from os import listdir
from os.path import isfile, join

from generator import RandomWalkGenerator
from training_data import _TrainingDataEncoder

from generate import _generate_optimal_state_value_pairs_for_problem


def generateAndSolve(domain, problem, max_time_per_problem, length):
    """Generate one random-walk instance of *problem* and solve it with FD.

    A ``RandomWalkGenerator`` driven by ``../run_my_generator.sh`` (resolved
    relative to the current working directory) is created with walk-length
    arguments ``length`` and ``length + 1`` and asked for a single instance.
    That instance is then passed to
    ``_generate_optimal_state_value_pairs_for_problem`` with the
    ``max_time_per_problem`` timeout.

    Args:
        domain: path of the PDDL domain file.
        problem: path of the PDDL problem the random walk starts from.
        max_time_per_problem: timeout forwarded to the solver helper.
        length: requested random-walk length.

    Returns:
        ``(time_taken, state_value_pairs)``: seconds elapsed (monotonic clock)
        from the start of generation until the solver helper returned, and the
        helper's result.

    The generated problem file is always deleted, even if solving raises.
    The ``.sas`` / ``.plan`` paths derived from the temporary name are not
    removed here; whether the solver helper cleans them up is not visible
    from this function.
    """
    start_time = time.perf_counter()
    script = "../run_my_generator.sh"
    generator = RandomWalkGenerator(script, domain, length, length + 1)

    # The temporary file only reserves a unique name in the CWD for the
    # generator and solver outputs; the context manager guarantees it is
    # closed (and deleted) once we are done instead of relying on GC.
    with tempfile.NamedTemporaryFile(dir=".") as tf:
        generated_problems = generator.call_generator(problem, 1, tf.name)
        problem_file = generated_problems[0]
        try:
            state_value_pairs = _generate_optimal_state_value_pairs_for_problem(
                domain,
                problem_file,
                max_time_per_problem,
                sas_name=tf.name + ".sas",
                plan_name=tf.name + ".plan",
            )
            time_taken = time.perf_counter() - start_time
        finally:
            # Remove the generated problem even when the solver fails.
            os.remove(problem_file)
    return time_taken, state_value_pairs


def adaptableGeneratorWorker(
    time_budget,
    min_time_per_problem,
    max_time_per_problem,
    domain,
    problems,
    shared_min_length,
    shared_max_length,
    lock,
    name,
):
    """Adaptively search for random-walk lengths that yield interesting problems.

    Until ``time_budget`` seconds have elapsed, repeatedly pick a random
    problem, generate and solve an instance whose walk length is drawn from
    that problem's current ``[min, max]`` window, and adapt the window from
    the measured time:

    * time >= ``max_time_per_problem`` (too hard): pull the max length down
      towards the sampled length. If the window has already collapsed
      (max <= min), reset the min length to 0 instead.
    * time < ``min_time_per_problem`` (too easy): push the min length up
      towards the sampled length.
    * otherwise the sample is good and its state-value pairs are kept.

    Args:
        time_budget: seconds this worker should keep sampling.
        min_time_per_problem: instances taking less than this many seconds
            are considered trivial.
        max_time_per_problem: solver timeout in seconds; instances taking at
            least this long are considered too hard.
        domain: single domain file shared by all the problems.
        problems: non-empty list of problems used as random-walk seeds.
        shared_min_length, shared_max_length: dict-like objects shared between
            workers, mapping each problem to the current estimate of the min
            and max random-walk length of good problems.
        lock: lock guarding read-modify-write updates of the shared dicts.
        name: unique worker name, used in log lines and the output file name.

    Side effects:
        Writes the good samples to ``data_file_{name}.json``.

    Returns:
        The log: one ``[problem, min_length, max_length, length, time]`` entry
        per sample, handy for tuning the generation time after the fact.

    Raises:
        ValueError: if ``problems`` is empty.
    """
    if not problems:
        raise ValueError(f"{name}: no problems to sample random walks from")

    cutoff_time = time.monotonic() + time_budget

    log = []
    # Optimal state-value pairs of the good samples, per problem.
    state_value_pairs_per_problem = {problem: [] for problem in problems}

    while time.monotonic() < cutoff_time:
        p = random.choice(problems)

        min_length = shared_min_length[p]
        max_length = shared_max_length[p]
        length = int(random.uniform(min_length, max_length))
        t, state_value_pairs = generateAndSolve(domain, p, max_time_per_problem, length)
        log.append([p, min_length, max_length, length, t])

        if t >= max_time_per_problem:
            # Too hard. Re-read the window under the lock so concurrent
            # updates from other workers are not lost.
            with lock:
                current_min = shared_min_length[p]
                current_max = shared_max_length[p]
                collapsed = current_max <= current_min
                if collapsed:
                    # Even the lower end of the window times out: reopen
                    # the window from below rather than moving the max.
                    shared_min_length[p] = 0
                else:
                    proposed_max = (current_max + length) // 2
                    if proposed_max < current_max:
                        shared_max_length[p] = proposed_max
                new_max = shared_max_length[p]

            if collapsed:
                print(name, "Min and max length are equal. Forcing min length to be 0")
            elif current_max - current_min < 10:
                print(
                    name,
                    "Min and max length are too close together. Try decreasing MIN_LENGTH",
                )
            print(name, p, "new max =", new_max)

        elif t < min_time_per_problem:
            # Too easy: step the minimum up quickly with a weighted average
            # that favours the sampled length.
            with lock:
                current_min = shared_min_length[p]
                proposed_min = (current_min + 3 * length) // 4
                if proposed_min > current_min:
                    shared_min_length[p] = proposed_min
                new_min = shared_min_length[p]
            print(name, p, "new min =", new_min)

        else:
            print(name, p, "good sample")
            state_value_pairs_per_problem[p].append(state_value_pairs)

    print("Storing optimal solutions.")
    with open(f"data_file_{name}.json", "w", encoding="utf-8") as write_file:
        json.dump(
            state_value_pairs_per_problem,
            write_file,
            indent=4,
            cls=_TrainingDataEncoder,
        )
    print("Worker", name, "done!")
    pprint.pprint(log)
    return log


# Initial random-walk length window assigned to every problem before the
# workers start adapting it. Defined at module level rather than under the
# ``__main__`` guard so the names exist in every process, including workers
# started with the "spawn" start method, which re-import this module without
# running the guarded code.
MIN_LENGTH = 50
MAX_LENGTH = 60


def main():
    """Tune random-walk lengths for every problem in parallel workers.

    Usage: ``<script> PROBLEM_DIR``. PROBLEM_DIR must contain ``domain.pddl``
    plus the problem ``*.pddl`` files; any file whose name contains "domain"
    is not treated as a problem. After all workers finish, the converged
    ``[min, max]`` length window of each problem is printed.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} PROBLEM_DIR")

    problem_dir = sys.argv[1]
    domain = join(problem_dir, "domain.pddl")
    # Sorted so the final report has a stable order.
    problems = sorted(
        join(problem_dir, f)
        for f in listdir(problem_dir)
        if f.endswith(".pddl") and "domain" not in f and isfile(join(problem_dir, f))
    )
    if not problems:
        sys.exit(f"no problem files found in {problem_dir}")

    # NOTE(review): these are short debug settings; the values left in
    # comments by the original author were 10 hours, 5 minutes and 30 minutes.
    time_budget_in_secs = 60  # total time each worker generates data for
    min_time_per_problem = 0  # problems solved faster than this are trivial
    max_time_per_problem = 10  # solver timeout for hard problems
    num_workers = 4  # number of worker processes

    with multiprocessing.Manager() as manager:
        shared_min_length = manager.dict({p: MIN_LENGTH for p in problems})
        shared_max_length = manager.dict({p: MAX_LENGTH for p in problems})
        lock = multiprocessing.Lock()

        workers = []
        for i in range(num_workers):
            name = f"worker-{i}"
            workers.append(
                multiprocessing.Process(
                    target=adaptableGeneratorWorker,
                    name=name,
                    args=(
                        time_budget_in_secs,
                        min_time_per_problem,
                        max_time_per_problem,
                        domain,
                        problems,
                        shared_min_length,
                        shared_max_length,
                        lock,
                        name,
                    ),
                )
            )

        for w in workers:
            w.start()
        for w in workers:
            w.join()

        # Show the converged length windows (while the manager is alive).
        for k in problems:
            print(k, shared_min_length[k], shared_max_length[k])


if __name__ == "__main__":
    main()
