
from collections import defaultdict
import json
import logging
import os
import re

from downward.cached_revision import CachedRevision
#from rl_fast_downward_experiment import RLCachedRevision
from downward.experiment import (
    FastDownwardExperiment, FastDownwardRun, _DownwardAlgorithm)

from lab.experiment import Run

# Matches directory names of the form "runs-<digits>-<digits>" inside an
# experiment directory (used by find_task2meta to find the run batches).
REGEX_RUNS_DIRS = re.compile(r"runs-\d+-\d+")
# Captures the number N of a problem-file path ending in "p<N>.pddl" (used
# by _add_runs to pick a model per task).
REGEX_TASK_IDX = re.compile(r".*p(\d+)\.pddl$")


class RLRobustnessRun(Run):
    """One lab run that evaluates a single network model on a single task.

    The run symlinks the task's domain/problem PDDL files and the model into
    its directory, writes *used_atoms* and *default_atom_values* into
    ``network_atoms.csv`` / ``network_defaults.csv`` and calls the planner.
    In the algorithm's component options, the placeholders
    ``{network_input_atoms}`` and ``{network_input_defaults}`` are replaced
    by references to those two files.
    """

    def __init__(self, exp, algo, task, domain_task, path_model, used_atoms,
                 default_atom_values):
        Run.__init__(self, exp)
        # Store the inputs first: _set_properties() reads them.
        self.algo, self.task, self.domain_task = algo, task, domain_task
        self.path_model = path_model
        self.used_atoms = used_atoms
        self.default_atom_values = default_atom_values

        self._set_properties()
        # Overrides a "domain" value copied over from the task properties.
        self.set_property("domain", domain_task)

        # Symlinking the inputs instead of copying them makes building the
        # experiment considerably faster.
        linked_resources = (
            ('domain', task.domain_file, 'domain.pddl'),
            ('problem', task.problem_file, 'problem.pddl'),
            ("model", path_model, "model.pb"),
        )
        for resource_name, source, destination in linked_resources:
            self.add_resource(resource_name, source, destination, symlink=True)
        self.add_new_file("network_atoms", "network_atoms.csv", used_atoms)
        self.add_new_file("network_defaults", "network_defaults.csv",
                          default_atom_values)

        # Placeholder in a component option -> lab file reference for it.
        substitutions = (
            ("{network_input_atoms}", "[file network_atoms.csv]"),
            ("{network_input_defaults}", "[file network_defaults.csv]"),
        )
        component_options = []
        for option in algo.component_options:
            for placeholder, file_reference in substitutions:
                option = option.replace(placeholder, file_reference)
            component_options.append(option)

        planner = '{' + algo.cached_revision.get_planner_resource_name() + '}'
        command = ([planner] + algo.driver_options + ['{domain}', '{problem}'] +
                   component_options)
        output_limit = 10240
        self.add_command(
            'planner', command,
            soft_stdout_limit=output_limit, hard_stdout_limit=5 * output_limit,
            soft_stderr_limit=output_limit, hard_stderr_limit=5 * output_limit)

    def _set_properties(self):
        """Record the algorithm, revision, model and task as run properties."""
        revision = self.algo.cached_revision
        entries = [
            ('algorithm', self.algo.name),
            ('repo', revision.repo),
            ('local_revision', revision.local_rev),
            ('global_revision', revision.global_rev),
            ('revision_summary', revision.summary),
            ('build_options', revision.build_options),
            ('driver_options', self.algo.driver_options),
            ('component_options', self.algo.component_options),
            ('path_model', self.path_model),
        ]
        # Task properties come after (and may override) the entries above.
        entries.extend(self.task.properties.items())
        entries.append(('experiment_name', self.experiment.name))
        entries.append(
            ('id', [self.algo.name, self.domain_task, self.task.problem]))
        for key, value in entries:
            self.set_property(key, value)


class RLRobustnessAlgorithm(object):
    """Settings of one algorithm evaluated by RLRobustnessExperiment.

    :param name: algorithm name
    :param cached_revision: planner revision to run
    :param nn_experiment: directory or iterable of directories holding the
        trained models; every directory must exist
    :param driver_options: planner driver options
    :param component_options: planner component options
    :param max_nb_models: optional limit on the number of models per task
    :param skip_first_model: whether to skip models from the first
        *nn_experiment* directory
    :param sort_by_coverage_ratio: None, or a directory or iterable of
        directories used to sort the models; every directory must exist
    :raises AssertionError: if one of the given directories does not exist
    """

    def __init__(self, name, cached_revision, nn_experiment, driver_options,
                 component_options, max_nb_models=None,
                 skip_first_model=False,
                 sort_by_coverage_ratio=None):
        self.name = name
        self.cached_revision = cached_revision
        if isinstance(nn_experiment, str):
            nn_experiment = [nn_experiment]
        # Materialise into a list. Otherwise a generator would be exhausted by
        # the validation below, and later consumers (find_task2meta, the
        # nn_experiment[0] lookup) would see nothing.
        nn_experiment = list(nn_experiment)
        missing = [x for x in nn_experiment if not os.path.isdir(x)]
        assert not missing, "Missing: %i, %s" % (len(missing), "\n".join(missing))
        self.nn_experiment = nn_experiment
        self.driver_options = driver_options
        self.component_options = component_options
        self.max_nb_models = max_nb_models
        self.skip_first_model = skip_first_model
        if isinstance(sort_by_coverage_ratio, str):
            sort_by_coverage_ratio = [sort_by_coverage_ratio]
        if sort_by_coverage_ratio is not None:
            # Same reason as above: keep the directories after validation.
            sort_by_coverage_ratio = list(sort_by_coverage_ratio)
        assert (sort_by_coverage_ratio is None or
                all(os.path.isdir(x) for x in sort_by_coverage_ratio)
                ), sort_by_coverage_ratio
        self.sort_by_coverage_ratio = sort_by_coverage_ratio


def extract_initial_atoms(file_pddl, used_atoms):
    """Compute the default network input value of each atom in the initial state.

    The ``(:init ...)`` block of *file_pddl* is parsed textually (lowercased),
    and each fact is rendered as ``"Atom name(arg1, arg2)"``. For every
    entry of *used_atoms*:

    * ``"Atom ..."`` -> 1 if the fact is in the init block, else 0;
    * ``"NegatedAtom ..."`` -> 0 if the corresponding positive fact is in the
      init block, else 1.

    :param file_pddl: path of a PDDL problem file
    :param used_atoms: iterable of atom strings
    :return: list of 0/1 ints, one per entry of *used_atoms*
    :raises AssertionError: if the file is missing or no init block / no
        following ``(:`` block is found
    """
    assert os.path.isfile(file_pddl)
    # HACK: the init block is located textually, from "(:init" up to the next
    # "(:" (normally "(:goal"). PDDL comments and an init block that is the
    # last section of the file are not handled.
    block_start = "(:"
    init_block_start = "%sinit" % block_start
    with open(file_pddl, "r") as f:
        pddl = f.read().lower()
    init_start = pddl.find(init_block_start)
    assert init_start > -1, file_pddl
    init_end = pddl.find(block_start, init_start + len(init_block_start))
    assert init_end > -1, "(this is not guaranteed...)"
    init = pddl[init_start + len(init_block_start): init_end]

    # Each "(" opens a fact. Its name and arguments run up to the next ")".
    pattern_atom = re.compile(r"\(([^\)]+)")
    facts = (x.split() for x in pattern_atom.findall(init) if x.strip())
    # A set makes the membership tests below O(1).
    init_atoms = set("Atom %s(%s)" % (x[0], ", ".join(x[1:])) for x in facts)

    negated_prefix = "NegatedAtom "
    default_values = []
    for atom in used_atoms:
        if atom.startswith(negated_prefix):
            # init_atoms only holds positive "Atom ..." strings. A negated atom
            # is false initially exactly when its positive counterpart is
            # listed in the init block.
            positive = "Atom " + atom[len(negated_prefix):]
            default_values.append(0 if positive in init_atoms else 1)
        else:
            default_values.append(1 if atom in init_atoms else 0)
    return default_values


def find_task2meta(dirs_exp):
    """Collect the trained models of one or more training experiments by task.

    Each run directory inside the ``runs-*-*`` directories of an experiment
    contributes one entry if it contains ``model.pb``, ``static-properties``
    and ``used_atoms.json``. The entry is filed under the task
    ``<domain>/<problem>`` named in its static properties.

    :param dirs_exp: experiment directory or list of them
    :return: defaultdict(list) mapping ``<domain>/<problem>`` to a list of
        ``(path_model, used_atoms, default_values)`` tuples. *used_atoms* is
        the ";"-joined atom list (with ", " shortened to ","), and
        *default_values* is the ";"-joined output of extract_initial_atoms
        for the run's ``problem.pddl``.
    :raises AssertionError: if an experiment directory does not exist
    """
    if isinstance(dirs_exp, str):
        dirs_exp = [dirs_exp]
    task2meta = defaultdict(list)
    for dir_exp in dirs_exp:
        assert os.path.isdir(dir_exp)
        # os.listdir() returns entries in arbitrary order. _add_runs picks a
        # model by its index in these lists, so sort the listings to make the
        # choice reproducible across machines and file systems.
        runs_dirs = [os.path.join(dir_exp, x) for x in sorted(os.listdir(dir_exp))
                     if REGEX_RUNS_DIRS.match(x)]
        for runs_dir in runs_dirs:
            if not os.path.isdir(runs_dir):
                # A plain file that happens to match the naming pattern.
                continue
            for run_name in sorted(os.listdir(runs_dir)):
                run_dir = os.path.join(runs_dir, run_name)
                path_model = os.path.join(run_dir, "model.pb")
                path_static_properties = os.path.join(run_dir, "static-properties")
                path_used_atoms = os.path.join(run_dir, "used_atoms.json")
                path_problem = os.path.join(run_dir, "problem.pddl")
                if not (os.path.isfile(path_model) and
                        os.path.isfile(path_static_properties) and
                        os.path.isfile(path_used_atoms)):
                    continue
                with open(path_static_properties, "r") as f:
                    static_properties = json.load(f)
                with open(path_used_atoms, "r") as f:
                    used_atoms = json.load(f)
                key = os.path.join(static_properties["domain"],
                                   static_properties["problem"])
                default_values = extract_initial_atoms(path_problem, used_atoms)
                task2meta[key].append((
                    path_model,
                    ";".join(used_atoms).replace(", ", ","),
                    ";".join(str(x) for x in default_values)))
    return task2meta


def sort_task2meta(task2meta, sort_by_coverage_ratio):
    """Order the models of every task by their coverage ratio, best first.

    A model's coverage ratio is its summed "coverage" divided by the number
    of runs that used it. Both are read from the JSON "properties" file of
    each given directory; entries without "path_model" or "coverage" are
    ignored.

    *task2meta* is modified in place and returned. Each task's list is sorted
    by descending ratio (stable for ties). A task none of whose models has a
    ratio is set to None.

    :param task2meta: mapping task -> list of ``(path_model, ...)`` tuples,
        as produced by find_task2meta
    :param sort_by_coverage_ratio: directory or list of directories that
        each contain a "properties" file
    :return: *task2meta*
    :raises AssertionError: if a directory or properties file is missing, or
        if only some of a task's models have a coverage ratio
    """
    if isinstance(sort_by_coverage_ratio, str):
        sort_by_coverage_ratio = [sort_by_coverage_ratio]

    # absolute model path -> [summed coverage, number of runs]
    ratio = defaultdict(lambda: [0, 0])
    for dir_properties in sort_by_coverage_ratio:
        assert os.path.isdir(dir_properties), dir_properties
        file_properties = os.path.join(dir_properties, "properties")
        assert os.path.isfile(file_properties), file_properties
        with open(file_properties, "r") as f:
            properties = json.load(f)
        for props in properties.values():
            path_model = props.get("path_model")
            coverage = props.get("coverage")
            # Test for None *before* abspath(): abspath(None) raises TypeError.
            if path_model is None or coverage is None:
                continue
            stats = ratio[os.path.abspath(path_model)]
            stats[0] += coverage
            stats[1] += 1

    def has_ratio(entry):
        # Normalise both sides to absolute paths so that models found through
        # a relative experiment directory still match the ratio keys.
        return os.path.abspath(entry[0]) in ratio

    def coverage_ratio(entry):
        covered, nb_runs = ratio[os.path.abspath(entry[0])]
        return covered / float(nb_runs)

    for task, meta in list(task2meta.items()):
        if meta is None:
            continue
        if not any(has_ratio(x) for x in meta):
            # Either something went wrong (highly unlikely), or this task was
            # trained unnecessarily: it is in neither the medium nor the hard
            # task set, but it was trained because its index is higher than
            # that of the smallest required task.
            task2meta[task] = None
            continue

        missing = [x[0] for x in meta if not has_ratio(x)]
        assert not missing, task + ":\n" + "\n".join(missing)
        task2meta[task] = sorted(meta, key=coverage_ratio, reverse=True)
    return task2meta


def parse_original_domain_task(path):
    """Map a problem file path to ``<domain>/<original task>.pddl``.

    The original task is the name of the directory containing *path*. Its
    domain is the name of the directory above that. The task name is cut
    right after its first ".pddl", or gets ".pddl" appended if it has none.
    """
    task_dir = os.path.dirname(path)
    domain_name = os.path.basename(os.path.dirname(task_dir))
    # partition() yields the part before the first ".pddl", or the whole name
    # when ".pddl" does not occur; either way the suffix is re-added.
    stem = os.path.basename(task_dir).partition(".pddl")[0]
    return os.path.join(domain_name, stem + ".pddl")


class RLRobustnessExperiment(FastDownwardExperiment):
    """Experiment that evaluates trained network models on a task suite.

    For every algorithm, the models found in its ``nn_experiment``
    directories are matched to the suite's tasks through
    parse_original_domain_task. One RLRobustnessRun is added per matching
    task.
    """

    def __init__(self, *args, **kwargs):
        FastDownwardExperiment.__init__(self, *args, **kwargs)

    def add_algorithm(self, name, repo, rev, nn_experiment, component_options,
                      build_options=None, driver_options=None, max_nb_models=None,
                      skip_first_model=False,
                      sort_by_coverage_ratio=None):
        """Register an algorithm that evaluates the models of *nn_experiment*.

        :param name: unique algorithm name (must be a string)
        :param repo: repository passed to CachedRevision
        :param rev: revision passed to CachedRevision
        :param nn_experiment: single/list of training experiment directories
            containing the models
        :param component_options: planner component options. The placeholders
            ``{network_input_atoms}`` and ``{network_input_defaults}`` are
            replaced by the per-run network files.
        :param build_options: build options passed to CachedRevision
            (default: none)
        :param driver_options: driver options placed before the PDDL files on
            the planner command line (default: none)
        :param max_nb_models: limit the number of different models to use
            during evaluation (by default loops over all matching models for
            an ipc task). A task ``p<N>.pddl`` uses model N modulo
            min(max_nb_models, number of models); must be at least 1.
        :param skip_first_model: if true, no run is added when the model
            selected for a task lies inside the first *nn_experiment* directory
        :param sort_by_coverage_ratio: single/list of directories whose
            "properties" file is used to sort the given models by their
            coverage ratio
        """
        if not isinstance(name, str):
            logging.critical('Algorithm name must be a string: {}'.format(name))
        if name in self._algorithms:
            logging.critical('Algorithm names must be unique: {}'.format(name))
        if max_nb_models is not None and max_nb_models < 1:
            # 0 would later cause a ZeroDivisionError in the modulo in _add_runs.
            logging.critical(
                'max_nb_models must be at least 1: {}'.format(max_nb_models))
        build_options = build_options or []
        # RLRobustnessRun concatenates driver_options with lists, so the
        # default None has to become an (empty) list.
        driver_options = list(driver_options or [])
        self._algorithms[name] = RLRobustnessAlgorithm(
            name, CachedRevision(repo, rev, build_options),
            nn_experiment, driver_options, component_options,
            max_nb_models=max_nb_models,
            skip_first_model=skip_first_model,
            sort_by_coverage_ratio=sort_by_coverage_ratio
        )

    def _add_code(self):
        FastDownwardExperiment._add_code(self)

    def _add_runs(self):
        for algo in self._algorithms.values():
            task2meta = find_task2meta(algo.nn_experiment)
            max_nb_models = float("inf") if algo.max_nb_models is None else algo.max_nb_models
            print("MAX NB MODELS", max_nb_models)
            print("NB DIFF TASK To TEST ", len(task2meta))
            if algo.sort_by_coverage_ratio is not None:
                # Sorts task2meta in place.
                sort_task2meta(task2meta, algo.sort_by_coverage_ratio)

            # Models under this prefix are skipped. The trailing path
            # separator keeps e.g. "exp1" from also matching "exp10/...".
            skip_prefix = None
            if algo.skip_first_model and algo.nn_experiment:
                skip_prefix = os.path.join(algo.nn_experiment[0], "")

            skipped = set()
            added = set()
            for task in self._get_tasks():
                domain_task = parse_original_domain_task(task.problem_file)
                meta = task2meta.get(domain_task)
                if meta is None:
                    # No (usable) model was found for this task.
                    continue
                match = REGEX_TASK_IDX.match(task.problem_file)
                assert match is not None, task.problem_file
                # Pick the model by the problem's number, cycling over the first
                # max_nb_models models.
                idx = int(match.group(1)) % min(max_nb_models, len(meta))
                path_model, used_atoms, default_atom_values = meta[idx]

                if skip_prefix is not None and path_model.startswith(skip_prefix):
                    skipped.add(path_model)
                    continue
                added.add(path_model)

                self.add_run(RLRobustnessRun(
                    self, algo, task, domain_task,
                    path_model, used_atoms, default_atom_values))
            print("Nb Tasks Skipped", len(skipped), "Added", len(added))
