from abc import ABC, abstractmethod
import tempfile
import logging
import random
import os
import shlex
import subprocess
import pathlib
import time

from strips_hgn.training_data.generate import _get_state_value_pairs

from strips_hgn.planning.fast_downward_api import (
    get_optimal_actions_using_fd_with_timeout,
)

from strips_hgn.planning import STRIPSProblem, get_strips_problem
from strips_hgn.planning.utils import (
    max_number_of_add_effects,
    max_number_of_preconditions,
)

# Module-level logger named after this module, so its output can be configured
# through the standard logging hierarchy.
_log = logging.getLogger(__name__)


class Generator(ABC):
    """
    Abstract interface for generators of planning problem instances.

    Every method below is abstract, so ``ABC`` refuses to instantiate a
    subclass that does not override all of them.
    """

    @abstractmethod
    def initialize_parameter_space(self, fd_planning_time, total_time):
        """
        Calibrate the range of generator parameters that will be sampled.

        :param fd_planning_time: per-problem planner time limit (seconds in
            the existing implementation).
        :param total_time: overall time budget for the calibration.
        """
        raise NotImplementedError

    @abstractmethod
    def is_initialized(self):
        """Return whether the generator is ready to sample problems."""
        raise NotImplementedError

    @abstractmethod
    def sample_problem(self, file_name):
        """
        Generate a problem instance, write it to ``file_name`` and return it.

        The signature matches the concrete implementation in
        ``SingleParameterGenerator``.
        """
        raise NotImplementedError

    @abstractmethod
    def max_receivers(self):
        """Return the receiver bound for problems produced by this generator."""
        raise NotImplementedError

    @abstractmethod
    def max_senders(self):
        """Return the sender bound for problems produced by this generator."""
        raise NotImplementedError


class RandomWalkGenerator:
    """ Uses the random walk generator of Patrick Ferber to generate data.

    The generator script is run through the shell as
    ``./<script> <instance> <num_problems> <lower_bound> <upper_bound> <prefix>``
    and its output is read back from ``<prefix>_p<i>.pddl`` for
    ``i in range(num_problems)``.
    """

    def __init__(self, generator_script, domain, lower_bound, upper_bound):
        """
        :param generator_script: script path relative to the current working
            directory (it is executed as ``./<generator_script>``).
        :param domain: domain passed to ``get_strips_problem`` for every
            generated instance.
        :param lower_bound: forwarded unchanged to the script.
        :param upper_bound: forwarded unchanged to the script.
        """
        self.script = generator_script
        self.domain = domain
        self._lb = lower_bound
        self._ub = upper_bound

    def initialize_parameter_space(self, fd_planning_time, total_time):
        """Not supported: this generator does not calibrate a parameter space."""
        # For now we do not initialize the parameter space
        raise NotImplementedError

    def is_initialized(self):
        """Always ready; there is no calibration step."""
        return True

    def _build_command(self, instance, num_problems, prefix):
        """Return the shell command line for one generator run, with every argument shell-quoted."""
        args = [f"./{self.script}", instance, num_problems, self._lb, self._ub, prefix]
        # Quote each argument so paths with spaces or shell metacharacters are
        # passed through literally instead of being interpreted by the shell.
        return " ".join(shlex.quote(str(arg)) for arg in args)

    def call_generator(self, instance, num_problems, prefix):
        """
        Run the generator script and parse the problems it writes.

        :param instance: instance argument forwarded to the script
            (presumably a seed problem file — confirm against the script).
        :param num_problems: number of problems the script should produce.
        :param prefix: output prefix; problem ``i`` is read from
            ``<prefix>_p<i>.pddl``.
        :returns: list of parsed problems in index order.
        """
        call_string = self._build_command(instance, num_problems, prefix)
        exitcode = subprocess.call(call_string, shell=True)
        if exitcode != 0:
            # Keep the best-effort flow, but surface the failure instead of
            # silently discarding the exit code; parsing below will fail
            # loudly if the expected files were not written.
            _log.warning(f"Generator exited with code {exitcode}: {call_string}")
        return [
            get_strips_problem(self.domain, f"{prefix}_p{i}.pddl")
            for i in range(num_problems)
        ]


class SingleParameterGenerator(Generator):
    """
    Generators that generate instances based on a single parameter.

    The external generator is run as a shell command:
    ``generator_call_string`` followed by the parameter value, with stdout
    redirected into the target problem file. Parameters are drawn from
    ``[lower_bound, max_param]``. ``max_param`` starts at ``lower_bound`` and
    is raised by ``initialize_parameter_space``.
    """

    def __init__(self, domain_dir, generator_call_string, lower_bound, upper_bound):
        """
        Initializes the generator with the domain and the shell command used to
        call the instance generator.

        :param domain_dir: domain passed to ``get_strips_problem`` and the
            planner (presumably the domain PDDL path — confirm with callers).
        :param generator_call_string: shell command prefix; the parameter value
            is appended as its last argument.
        :param lower_bound: smallest parameter value that will be used.
        :param upper_bound: largest parameter value that will be tried.

        Side effect: runs the generator once with ``upper_bound`` to compute
        ``max_senders`` / ``max_receivers``.
        """
        self.domain = domain_dir
        self.generator_call_string = generator_call_string
        self.lb = lower_bound
        self.ub = upper_bound
        self.initialized = False
        # Generate a problem of upper-bound size and use it to compute
        # max_senders and max_receivers. The context manager removes the
        # temporary file deterministically once we are done with it.
        with tempfile.NamedTemporaryFile() as tf:
            hardest_problem = self.call_generator(upper_bound, tf.name)
            self.receivers = max_number_of_add_effects([hardest_problem])
            self.senders = max_number_of_preconditions([hardest_problem])
        _log.info(f"Max receivers: {self.receivers} Max Senders: {self.senders}")
        # Until the parameter space is initialized only the lower bound is sampled.
        self.max_param = self.lb

    def max_senders(self):
        """Return the precondition-based bound computed from the upper-bound problem."""
        return self.senders

    def max_receivers(self):
        """Return the add-effect-based bound computed from the upper-bound problem."""
        return self.receivers

    def _build_call_string(self, num_parameter, instance_name):
        """Return the shell command that writes an instance of size num_parameter to instance_name."""
        # TODO assumes that generator call has num_parameter as last argument.
        # Instead we could also just ask the user to pass a lambda function that
        # takes num_parameter as argument and calls the generator to generate an
        # instance
        # The output path is quoted so the redirect works for any file name.
        return (
            f"{self.generator_call_string} {num_parameter} "
            f"> {shlex.quote(str(instance_name))}"
        )

    def call_generator(self, num_parameter, instance_name):
        """
        Call the generator with its parameter to create a new instance with
        filename instance_name.

        :returns: the problem parsed by ``get_strips_problem`` from
            ``instance_name``.
        """
        call_string = self._build_call_string(num_parameter, instance_name)
        _log.info(call_string)
        exitcode = subprocess.call(call_string, shell=True)
        if exitcode != 0:
            # Don't abort (parsing below reports missing/invalid output), but
            # make the generator failure visible.
            _log.warning(f"Generator exited with code {exitcode}: {call_string}")
        return get_strips_problem(self.domain, instance_name)

    def binary_search(self, fd_timeout, start_time, end_time, problem_dir="problems"):
        """
        Binary-search the largest parameter in ``[lb, ub]`` whose generated
        instance the planner solves within ``fd_timeout`` seconds.

        Assumes solvability is monotone in the parameter; otherwise the result
        is only a local boundary. The search stops once ``end_time`` (a
        ``time.time()`` timestamp) has passed.

        :param problem_dir: directory for generated instances; created if it
            does not exist.
        :returns: ``(max_parameter_solution, state_values_per_problem)``. The
            dict maps each solved instance's file name to the result of
            ``_get_state_value_pairs``. If nothing was solved,
            ``max_parameter_solution`` is ``lb``.

        NOTE: generated instances are kept on disk, including unsolved ones.
        """
        # NamedTemporaryFile(dir=...) fails if the directory is missing.
        os.makedirs(problem_dir, exist_ok=True)
        state_values_per_problem = dict()
        param_min = self.lb
        param_max = self.ub
        max_parameter_solution = param_min
        while param_min <= param_max and time.time() < end_time:
            current = (param_min + param_max) // 2
            # Reserve a unique file name and close our handle right away; the
            # file is kept (delete=False) because solved instances are keyed
            # by their file name.
            with tempfile.NamedTemporaryFile(
                dir=problem_dir, prefix="problem_", suffix=".pddl", delete=False
            ) as tf:
                problem_file = tf.name
            problem = self.call_generator(current, problem_file)
            _log.info(f"Solving problem of size {current}")
            # By default use FD_TIMEOUT. If we do not have sufficient time use the
            # remaining time we have
            planning_time = min(fd_timeout, end_time - time.time())
            if planning_time <= 0:
                # Budget ran out while generating; never hand the planner a
                # non-positive timeout.
                break
            plan = get_optimal_actions_using_fd_with_timeout(
                self.domain, problem_file, planning_time
            )
            _log.info(f"Passed time: {time.time() - start_time}")
            if plan:
                param_min = current + 1
                max_parameter_solution = max(max_parameter_solution, current)
                state_values_per_problem[problem_file] = _get_state_value_pairs(
                    problem, plan
                )
            else:
                param_max = current - 1
        return max_parameter_solution, state_values_per_problem

    def initialize_parameter_space(self, fd_planning_time, total_time):
        """
        Uses binary search to compute the hardest problem that we can solve
        within fd_planning_time seconds, spending roughly total_time seconds.

        Sets ``max_param`` to the result and marks the generator initialized.

        :returns: dict mapping each solved instance's file name to its
            state-value pairs.
        """
        start_time = time.time()
        end_time = start_time + total_time
        max_param, pairs = self.binary_search(fd_planning_time, start_time, end_time)
        self.max_param = max_param
        self.initialized = True
        return pairs

    def sample_problem(self, file_name):
        """Generate an instance with a parameter drawn uniformly from [lb, max_param] into file_name."""
        parameter = random.randint(self.lb, self.max_param)
        return self.call_generator(parameter, file_name)

    def is_initialized(self):
        """Return whether initialize_parameter_space has completed."""
        return self.initialized
