from collections import defaultdict
import glob
import json
import logging
import os
import sys

if sys.version_info < (3,):
    import subprocess32 as subprocess
else:
    import subprocess

import tools

from downward.cached_revision import CachedRevision
from downward.experiment import (
    FastDownwardExperiment, FastDownwardRun, _DownwardAlgorithm)

from lab.experiment import Run


class RLCachedRevision(CachedRevision):
    """Cached revision that additionally names the RL driver resource."""

    def get_rl_resource_name(self):
        """Return the resource name under which the RL script is registered.

        The name is the fixed prefix 'fast_deepcube_' followed by this
        revision's hashed name.
        """
        return 'fast_deepcube_' + self._hashed_name

    def _cleanup(self):
        """Prune the cached checkout and strip the planner binaries.

        Everything directly below "builds/<build>/" except the "bin"
        directory is deleted, the cached 'build.py' is removed, and the
        'downward' / 'preprocess' binaries found in the "bin" directories
        are passed to the external "strip" tool.
        """
        # Only keep the bin directories in "builds" dir.
        for path in glob.glob(os.path.join(self.path, "builds", "*", "*")):
            if os.path.basename(path) == 'bin':
                continue
            if os.path.isdir(path):
                tools.paths.remove_dir(path)
            else:
                tools.paths.remove(path)

        # Remove unneeded files.
        tools.paths.remove(self.get_cached_path('build.py'))

        # Strip binaries.
        bin_pattern = os.path.join(self.path, "builds", "*", "bin", "*")
        binaries = [
            path for path in glob.glob(bin_pattern)
            if os.path.basename(path) in ('downward', 'preprocess')]
        # Running "strip" without any file operand only yields a usage
        # error, so skip the call when no binary was found.
        if binaries:
            subprocess.call(['strip'] + binaries)


class RLRun(Run):
    """A single run of one RL algorithm on one task for one repetition."""

    def __init__(self, exp, algo, task, repetition):
        """Register properties, PDDL resources and the planner command.

        exp is the owning experiment, algo the algorithm description,
        task the planning task and repetition the repetition index that
        is stored with the run.
        """
        Run.__init__(self, exp)
        self.algo = algo
        self.task = task
        self.repetition = repetition

        self._set_properties()

        # Symlinking the PDDL files (rather than copying them) makes
        # building the experiment twice as fast.
        for kind in ('domain', 'problem'):
            self.add_resource(
                kind, getattr(self.task, kind + '_file'), kind + '.pddl',
                symlink=True)

        # The first command element is a placeholder for the RL script
        # resource of this algorithm's revision.
        script = '{' + algo.cached_revision.get_rl_resource_name() + '}'
        arguments = [algo.network, '{problem}', "--domain-pddl", '{domain}']
        soft_limit = 10240
        hard_limit = 5 * soft_limit
        self.add_command(
            'planner',
            [script] + arguments + algo.component_options,
            soft_stdout_limit=soft_limit, hard_stdout_limit=hard_limit,
            soft_stderr_limit=soft_limit, hard_stderr_limit=hard_limit)

    def _set_properties(self):
        """Store algorithm, revision, task and run metadata as properties.

        Task properties are written after the algorithm/revision values
        and before 'experiment_name', 'repetition' and 'id', so the
        latter three always win over equally named task properties.
        """
        algo = self.algo
        revision = algo.cached_revision

        self.set_property('algorithm', algo.name)
        revision_properties = (
            ('repo', 'repo'),
            ('local_revision', 'local_rev'),
            ('global_revision', 'global_rev'),
            ('revision_summary', 'summary'),
            ('build_options', 'build_options'),
        )
        for prop_name, attr_name in revision_properties:
            self.set_property(prop_name, getattr(revision, attr_name))
        self.set_property('component_options', algo.component_options)
        self.set_property('network', algo.network)

        for key, value in self.task.properties.items():
            self.set_property(key, value)

        self.set_property('experiment_name', self.experiment.name)
        self.set_property('repetition', self.repetition)
        run_id = [algo.name, self.task.domain, self.task.problem,
                  str(self.repetition)]
        self.set_property('id', run_id)


class RLAlgorithm(object):
    """Plain description of one RL algorithm configuration.

    filter_previous_robustness is either None or a 2-element sequence
    (properties_file, threshold). If it is None or its first element is
    None, both prev_robustness_* attributes are None; otherwise the
    properties file must exist.
    """

    def __init__(self, name, cached_revision, network, component_options,
                 repetitions, filter_previous_robustness=None):
        self.name = name
        self.cached_revision = cached_revision
        self.network = network
        self.component_options = component_options
        self.driver_options = []
        self.repetitions = repetitions

        robustness = filter_previous_robustness
        assert robustness is None or len(robustness) == 2, robustness

        # Filtering is only active when a properties file is given.
        filtering = robustness is not None and robustness[0] is not None
        self.prev_robustness_props = robustness[0] if filtering else None
        self.prev_robustness_threshold = robustness[1] if filtering else None
        if filtering:
            assert os.path.isfile(self.prev_robustness_props), \
                self.prev_robustness_props


class RLFastDownwardExperiment(FastDownwardExperiment):
    """Fast Downward experiment whose runs execute the RL driver script."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to FastDownwardExperiment."""
        FastDownwardExperiment.__init__(self, *args, **kwargs)

    def add_algorithm(self, name, repo, rev, network, component_options,
                      build_options=None, driver_options=None, repetitions=1,
                      filter_previous_robustness=None):
        """Register an RL algorithm under the unique string *name*.

        repo, rev and build_options (default: empty list) identify the
        cached revision. network and component_options are passed to the
        planner command, network first. driver_options is unsupported
        and must be None. repetitions is the number of runs created per
        task. filter_previous_robustness is None or a pair
        (properties_file, threshold) used by _add_runs to skip tasks.

        A non-string or duplicate name is reported via logging.critical.
        """
        assert driver_options is None
        # NOTE(review): presumably the lab framework configures logging so
        # that critical messages abort the program; with unconfigured
        # stdlib logging they are only reported -- confirm.
        if not isinstance(name, str):
            logging.critical('Algorithm name must be a string: {}'.format(name))
        if name in self._algorithms:
            logging.critical('Algorithm names must be unique: {}'.format(name))
        build_options = build_options or []
        self._algorithms[name] = RLAlgorithm(
            name, RLCachedRevision(repo, rev, build_options),
            network, component_options, repetitions, filter_previous_robustness)

    def _add_code(self):
        """Add the base resources plus 'fast-deepcube.py' per revision."""
        FastDownwardExperiment._add_code(self)
        for cached_rev in self._get_unique_cached_revisions():
            self.add_resource(
                cached_rev.get_rl_resource_name(),
                cached_rev.get_cached_path('fast-deepcube.py'),
                cached_rev.get_exp_path('fast-deepcube.py'))

    def _add_runs(self):
        """Create one RLRun per algorithm, task and repetition.

        When an algorithm has previous robustness data, a task is skipped
        if it does not occur in that data or if its previous coverage
        ratio is at least the algorithm's threshold.
        """
        for algo in self._algorithms.values():
            meta = load_robustness_props(algo)
            for task in self._get_tasks():
                if meta is not None:
                    domain_task = parse_original_domain_task(task.problem_file)
                    if domain_task not in meta:
                        continue
                    cov, total = meta[domain_task]
                    # float() keeps the ratio a true division on Python 2.
                    if float(cov) / total >= algo.prev_robustness_threshold:
                        continue
                for r in range(algo.repetitions):
                    self.add_run(RLRun(self, algo, task, r))


def parse_original_domain_task(path):
    """Return "<parent directory name>/<task name>.pddl" for *path*.

    The task name is the file name cut off right after its first
    ".pddl"; if the file name contains no ".pddl", the suffix is
    appended instead.
    """
    directory, filename = os.path.split(path)
    # Everything before the first ".pddl" (or the whole name if absent),
    # with the suffix re-attached, covers both cases in one expression.
    stem = filename.partition(".pddl")[0]
    return os.path.join(os.path.basename(directory), stem + ".pddl")


def load_robustness_props(algo):
    """Aggregate coverage per "domain" value from a previous properties file.

    Returns None when algo.prev_robustness_props is None. Otherwise the
    JSON file at that path is read and a defaultdict is returned that maps
    each entry's "domain" value to [summed coverage, number of entries].
    """
    props_path = algo.prev_robustness_props
    if props_path is None:
        return None

    with open(props_path, "r") as handle:
        runs = json.load(handle)

    # NOTE(review): _add_runs looks these keys up as "<domain>/<task>.pddl",
    # so the "domain" property presumably has that form -- confirm.
    tally = defaultdict(lambda: [0, 0])
    for run in runs.values():
        key = run["domain"]
        covered = run["coverage"]
        counts = tally[key]
        counts[0] += covered
        counts[1] += 1
    return tally


