
from downward.cached_revision import CachedRevision
#from rl_fast_downward_experiment import RLCachedRevision
from downward.experiment import (
    FastDownwardExperiment, FastDownwardRun, _DownwardAlgorithm)

from lab.experiment import Run

import json
import logging
import os
import re

# Matches names that start with "runs-<digits>-<digits>".
REGEX_RUNS_DIRS = re.compile(r"runs-\d+-\d+")
# Matches problem file names that start with "p<digits>.pddl"; the captured
# digits are the problem index used by get_auxiliary_data().
REGEX_PROBLEM_IDX = re.compile(r"p(\d+)\.pddl")


class RLEcaiRun(Run):
    """A single planner run that is shipped with a model file and an atoms file.

    In addition to the PDDL domain and problem files, the run registers the
    resources ``model`` (linked as ``model.pb``) and ``atoms`` (linked as
    ``atoms.json``) and adds one ``planner`` command.
    """

    def __init__(self, exp, algo, task, domain_task, path_model, path_atoms):
        """Set up properties, resources and the planner command of the run.

        :param exp: experiment this run belongs to (passed to ``Run``).
        :param algo: algorithm object providing ``name``, ``cached_revision``,
            ``driver_options`` and ``component_options``.
        :param task: task object providing ``domain``, ``problem``,
            ``domain_file``, ``problem_file`` and ``properties``.
        :param domain_task: value stored as the ``domain`` property of the
            run; it overrides any ``domain`` entry in ``task.properties``.
        :param path_model: path of the model file to link as ``model.pb``.
        :param path_atoms: path of the atoms file to link as ``atoms.json``.
        """
        Run.__init__(self, exp)
        self.algo = algo
        self.task = task
        self.path_model = path_model
        self.path_atoms = path_atoms
        # The attribute used to be misspelled; keep the old name as an alias
        # so that code relying on it continues to work.
        self.path_toms = path_atoms

        self._set_properties()
        # Must come after _set_properties() so that it wins over the value
        # copied from task.properties.
        self.set_property("domain", domain_task)

        # Linking to instead of copying the PDDL files makes building
        # the experiment twice as fast.
        self.add_resource(
            'domain', self.task.domain_file, 'domain.pddl', symlink=True)
        self.add_resource(
            'problem', self.task.problem_file, 'problem.pddl', symlink=True)
        self.add_resource("model", path_model, "model.pb", symlink=True)
        self.add_resource("atoms", path_atoms, "atoms.json", symlink=True)

        # Command line: <planner> <driver options> <domain> <problem>
        # <component options>. The placeholders in braces refer to the
        # resource names registered above and by the experiment.
        planner = '{' + algo.cached_revision.get_planner_resource_name() + '}'
        self.add_command(
            'planner',
            [planner] +
            algo.driver_options + ['{domain}', '{problem}'] +
            algo.component_options,
            soft_stdout_limit=10240, hard_stdout_limit=5 * 10240,
            soft_stderr_limit=10240, hard_stderr_limit=5 * 10240
        )

    def _set_properties(self):
        """Record algorithm, revision and task metadata as run properties."""
        revision = self.algo.cached_revision
        self.set_property('algorithm', self.algo.name)
        self.set_property('repo', revision.repo)
        self.set_property('local_revision', revision.local_rev)
        self.set_property('global_revision', revision.global_rev)
        self.set_property('revision_summary', revision.summary)
        self.set_property('build_options', revision.build_options)
        self.set_property('driver_options', self.algo.driver_options)
        self.set_property('component_options', self.algo.component_options)
        self.set_property('path_model', self.path_model)

        for key, value in self.task.properties.items():
            self.set_property(key, value)
        self.set_property('experiment_name', self.experiment.name)

        self.set_property(
            'id', [self.algo.name, self.task.domain, self.task.problem])


class RLEcaiAlgorithm(object):
    """Plain container describing one algorithm configuration.

    :param name: name of the algorithm.
    :param cached_revision: revision object of the planner to run.
    :param model_directory: existing directory that contains the models.
    :param model_template: format string that is combined with each entry of
        *model_template_values* via ``str.format`` to obtain a model file name.
    :param model_template_values: values substituted into *model_template*.
    :param driver_options: list of driver options; ``None`` means no options.
    :param component_options: list of component options; ``None`` means no
        options.
    :raises ValueError: if *model_directory* is not an existing directory.
    """

    def __init__(self, name, cached_revision, model_directory, model_template,
                 model_template_values, driver_options, component_options):
        # Validate explicitly: an assert would vanish under "python -O".
        if not os.path.isdir(model_directory):
            raise ValueError(
                'Model directory does not exist: {}'.format(model_directory))
        self.name = name
        self.cached_revision = cached_revision
        self.model_template = model_template
        self.model_template_values = model_template_values
        self.model_directory = model_directory
        # The options are concatenated with lists when the planner command is
        # assembled, so None has to be replaced by an empty list.
        self.driver_options = (
            [] if driver_options is None else driver_options)
        self.component_options = (
            [] if component_options is None else component_options)


# Maps a benchmark domain name to the location of its models, as used by
# get_auxiliary_data():
#   key   -- name of the grandparent directory of a task's problem file
#   value -- tuple (model_domain, rename) where model_domain is the name of
#            the subdirectory of the algorithm's model directory and rename
#            is a callable that turns the name of the problem file's parent
#            directory into the name of the directory below model_domain.
DOMAIN_MAPPING = {
    "blocks": ("blocksworld_ipc", lambda x: x),
    "depot": ("depot_fix_goals", lambda x: "depot_" + x),
    "grid": ("grid_fix_goals", lambda x: "grid_" + x),
    "npuzzle": ("npuzzle_ipc", lambda x: "npuzzle_" + x),
    "pipesworld-notankage": ("pipesworld-notankage_fix_goals", lambda x: "pipes_nt_" + x),
    "rovers": ("rovers", lambda x: "rovers_" + x),
    # "p18" and "p19" use the prefix "scanalyzer_", all others "scanalyzer11_".
    "scanalyzer-opt11-strips": ("scanalyzer-opt11-strips", lambda x: "scanalyzer" + ("" if x in ["p18", "p19"] else "11") + "_" + x),
    "transport-opt14-strips": ("transport-opt14-strips", lambda x: "transport_" + x),
    "visitall-opt14-strips": ("visitall-opt14-strips", lambda x: "visitall_" + x),
    "storage": ("storage", lambda x: "storage_" + x),
}


def _find_models(algo, domain, problem):
    """Look up the model files and the atoms file for one domain/problem pair.

    Returns ``(models, path_atoms)`` where *models* is the non-empty list of
    existing model files, or ``None`` if the domain is unknown, the model
    directory or ``atoms.json`` is missing, or no model file exists.
    """
    if domain not in DOMAIN_MAPPING:
        return None
    model_domain, rename_problem = DOMAIN_MAPPING[domain]
    dir_models = os.path.join(
        algo.model_directory, model_domain, rename_problem(problem))
    if not os.path.isdir(dir_models):
        return None
    path_atoms = os.path.join(dir_models, "atoms.json")
    if not os.path.isfile(path_atoms):
        return None
    candidates = [
        os.path.join(dir_models, algo.model_template.format(value))
        for value in algo.model_template_values]
    models = [path for path in candidates if os.path.isfile(path)]
    if not models:
        return None
    return models, path_atoms


def get_auxiliary_data(algo, task, cache):
    """Select the model and atoms file to use for *task* under *algo*.

    The domain and problem names are taken from the two directories above
    ``task.problem_file``; the problem index is parsed from the file name.

    :param algo: algorithm providing ``name``, ``model_directory``,
        ``model_template`` and ``model_template_values``.
    :param task: task providing ``problem_file``.
    :param cache: dict used to memoise the file system lookups; it is
        modified in place and may be shared between calls and algorithms.
    :returns: ``None`` if the file name does not match the expected pattern
        or no models are available, otherwise a tuple
        ``(domain/problem, path_model, path_atoms)``. The model is chosen
        round-robin by problem index among the existing model files.
    """
    match = REGEX_PROBLEM_IDX.match(os.path.basename(task.problem_file))
    if match is None:
        return None
    test_idx = int(match.group(1))

    problem_dir = os.path.dirname(task.problem_file)
    problem = os.path.basename(problem_dir)
    domain = os.path.basename(os.path.dirname(problem_dir))

    # The lookup depends on the algorithm's model directory and template, so
    # the algorithm has to be part of the key. Otherwise a cache shared
    # between algorithms would hand the first algorithm's models to all
    # others.
    key = (algo.name, domain, problem)
    if key not in cache:
        cache[key] = _find_models(algo, domain, problem)

    data = cache[key]
    if data is None:
        return None

    models, path_atoms = data
    path_model = models[test_idx % len(models)]
    return os.path.join(domain, problem), path_model, path_atoms


class RLEcaiExperiment(FastDownwardExperiment):
    """Fast Downward experiment whose runs are paired with model files.

    Algorithms are stored as RLEcaiAlgorithm objects, and one RLEcaiRun is
    created for every algorithm/task pair for which get_auxiliary_data()
    finds a model and an atoms file.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to FastDownwardExperiment."""
        FastDownwardExperiment.__init__(self, *args, **kwargs)

    def add_algorithm(self, name, repo, rev, model_directory, model_template,
                      model_template_values,
                      component_options, build_options=None, driver_options=None):
        """Register an algorithm under the unique name *name*.

        :param name: algorithm name; must be a string and not yet in use.
        :param repo: repository passed to ``CachedRevision``.
        :param rev: revision passed to ``CachedRevision``.
        :param model_directory: existing directory that contains the models.
        :param model_template: format string for model file names.
        :param model_template_values: values substituted into the template.
        :param component_options: list of planner component options.
        :param build_options: list of build options (default: empty list).
        :param driver_options: list of driver options (default: empty list).
        """
        if not isinstance(name, str):
            logging.critical('Algorithm name must be a string: {}'.format(name))
        if name in self._algorithms:
            logging.critical('Algorithm names must be unique: {}'.format(name))
        build_options = build_options or []
        # The run concatenates the driver options with other lists, so the
        # default of None has to become an empty list here.
        driver_options = driver_options or []
        self._algorithms[name] = RLEcaiAlgorithm(
            name, CachedRevision(repo, rev, build_options),
            model_directory, model_template, model_template_values,
            driver_options, component_options)

    def _add_code(self):
        """Add the planner code exactly as FastDownwardExperiment does."""
        FastDownwardExperiment._add_code(self)
        # NOTE: registering an additional per-revision resource is currently
        # disabled:
        # for cached_rev in self._get_unique_cached_revisions():
        #     self.add_resource(
        #         cached_rev.get_rl_resource_name(),
        #         cached_rev.get_cached_path('fast-deepcube.py'),
        #         cached_rev.get_exp_path('fast-deepcube.py'))

    def _add_runs(self):
        """Add one run per algorithm/task pair that has auxiliary data.

        Pairs for which get_auxiliary_data() returns None are skipped
        silently.
        """
        # Shared across all calls so that each model directory is inspected
        # only once.
        cache = {}
        for algo in self._algorithms.values():
            for task in self._get_tasks():
                aux = get_auxiliary_data(algo, task, cache)
                if aux is None:
                    continue
                domain_task, path_model, path_atoms = aux
                self.add_run(RLEcaiRun(
                    self, algo, task, domain_task,
                    path_model, path_atoms))

