from __future__ import print_function

from src.training.samplers import DirectorySampler

from .constants import SampleTypes

import json
import os
import psutil
import re
import string
import sys
import time

if sys.version_info >= (3,):
    import subprocess

    def decoder(s):
        """Python 3: return *s* unchanged; no decoding is performed.

        NOTE(review): bytes (e.g. from ``subprocess.check_output``) pass
        through undecoded here, unlike the Python 2 branch -- confirm callers
        expect that.
        """
        return s
else:
    # Python 2: use the subprocess32 backport under the name `subprocess`.
    import subprocess32 as subprocess

    def decoder(s):
        """Python 2: decode the byte string *s* with the default codec."""
        return s.decode()


def static_var(name, value):
    """Decorator factory attaching ``value`` to a function as attribute ``name``.

    This emulates a C-style static variable. The very same ``value`` object
    is stored on the function, so a mutable value (e.g. a dict) persists
    across calls and can serve as per-function state.

    :param name: attribute name to set on the decorated function
    :param value: object to store under that attribute
    :return: decorator that sets the attribute and returns the function itself
    """
    def _attach(fn):
        setattr(fn, name, value)
        return fn

    return _attach


def timing(old_time, msg):
    """Print the seconds elapsed since ``old_time`` and return the current time.

    :param old_time: earlier timestamp, as produced by ``time.time()``
    :param msg: %-format string that receives the elapsed seconds,
        e.g. ``"took %.2fs"``
    :return: the timestamp taken now, so successive calls can be chained
    """
    now = time.time()
    elapsed = now - old_time
    print(msg % elapsed)
    return now


def sort_nicely(l, sort_key=None):
    """
    Taken (5.11.2018) from:
    https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/

    Sort the given elements in the way that humans expect: runs of ASCII
    digits are compared numerically ("a2" < "a10"), everything else is
    compared as text.
    :param l: iterable of elements to sort (strings, or anything that
    ``sort_key`` maps to a string)
    :param sort_key: callable to extract the string by which elements in l
    shall be sorted (defaults to the identity)
    :return: a new sorted list; ``l`` itself is not modified
    """
    sort_key = (lambda x: x) if sort_key is None else sort_key

    def alphanum_key(item):
        # re.split with a capturing group yields alternating chunks:
        # even indices are text (possibly empty), odd indices are the
        # captured digit runs. Converting by position keeps int/str slots
        # aligned between keys and never calls int() on non-ASCII "digits"
        # such as "²" (for which str.isdigit() is True but int() fails).
        chunks = re.split(r'([0-9]+)', sort_key(item))
        return [int(chunk) if idx % 2 else chunk
                for idx, chunk in enumerate(chunks)]
    return sorted(l, key=alphanum_key)


def find_relevant_tasks(directory, directory_filters, problem_filters,
                        min_depth, max_depth):
    """Collect tasks below ``directory`` via a ``DirectorySampler``.

    The sampler is created with an empty first argument and configured by
    the remaining parameters; its ``iterable`` attribute is returned.
    NOTE(review): the exact meaning of the filters/depths and the element
    type of the result are defined by DirectorySampler -- confirm there.

    :param directory: root directory to search
    :param directory_filters: forwarded as ``filter_dir``
    :param problem_filters: forwarded as ``filter_file``
    :param min_depth: forwarded as ``selection_depth``
    :param max_depth: forwarded as ``max_depth``
    :return: ``DirectorySampler(...).iterable``
    """
    sampler_kwargs = dict(
        root=directory,
        filter_dir=directory_filters,
        filter_file=problem_filters,
        selection_depth=min_depth,
        max_depth=max_depth,
    )
    sampler = DirectorySampler([], **sampler_kwargs)
    return sampler.iterable


def get_common_prefix_suffix_regex(*elements):
    """
    Returns a regex of the form 'CommonPrefix(Uncommon1|UC2|...)CommonSuffix'.

    If all elements are equal, that element is returned unchanged. Prefix
    and suffix never overlap, so every element is matched in full
    (e.g. "ab", "abab" -> "ab(|ab)").

    NOTE(review): elements are not passed through ``re.escape``; regex
    metacharacters inside them (e.g. '.') keep their special meaning, so the
    result may also match strings other than the given ones.

    :param elements: strings to cover (at least one)
    :return: regex string matching those strings
    :raises ValueError: if no elements are given
    """
    if not elements:
        raise ValueError("at least one element is required")
    if all(elements[0] == x for x in elements):
        return elements[0]

    common_prefix = os.path.commonprefix(elements)
    common_suffix = os.path.commonprefix([x[::-1] for x in elements])[::-1]
    # Prefix and suffix are computed independently and may overlap inside
    # the shortest element (e.g. "a" vs "aa"); shrink the suffix so that
    # prefix + suffix fit into every element.
    max_suffix_len = min(len(x) for x in elements) - len(common_prefix)
    if len(common_suffix) > max_suffix_len:
        common_suffix = common_suffix[len(common_suffix) - max_suffix_len:]

    prefix_len = len(common_prefix)
    suffix_len = len(common_suffix)
    uncommon = "(%s)" % "|".join(
        elem[prefix_len:len(elem) - suffix_len] for elem in elements)

    return common_prefix + uncommon + common_suffix


@static_var("cache", {})
def get_data_stats(dir_data, basename_data):
    """Look up the size statistics recorded for one data set.

    Statistics are read from ``<dir_data>/data_set_sizes.json`` and cached
    per directory in ``get_data_stats.cache``. If the file does not exist,
    nothing is cached, so the file is looked for again on the next call.

    :param dir_data: directory containing ``data_set_sizes.json``
    :param basename_data: key of the data set inside that JSON file
    :return: the stored entry for ``basename_data``, or None if the file or
        the entry is missing
    """
    cache = get_data_stats.cache
    if dir_data not in cache:
        path_stats = os.path.join(dir_data, "data_set_sizes.json")
        if os.path.exists(path_stats):
            with open(path_stats, "r") as handle:
                cache[dir_data] = json.load(handle)
    if dir_data in cache and basename_data in cache[dir_data]:
        return cache[dir_data][basename_data]
    return None


# Keys of the per-data-set statistics dicts (see get_data_stats).
KEY_NB_SAMPLES = "#samples"
KEY_NB_PROBLEMS = "#problems"


def get_upper_sample_bound_from_data_stats(
        stats, sample_type, samples_per_problem):
    """Upper bound on the number of samples of a given type in a data set.

    :param stats: dict with at least ``KEY_NB_SAMPLES`` and/or
        ``KEY_NB_PROBLEMS`` (only the keys needed for ``sample_type`` are read)
    :param sample_type: None or ``SampleTypes.plan`` (all samples),
        ``SampleTypes.inter`` (samples minus one per problem) or
        ``SampleTypes.init`` (one per problem)
    :param samples_per_problem: optional per-problem cap; ignored for ``init``
    :return: the bound
    :raises ValueError: for any other ``sample_type``
    """
    # ASSUMPTION: all samples are from a solution trajectory! If not the
    # sample type 'plan' has to be managed differently
    if sample_type is None or sample_type == SampleTypes.plan:
        available = stats[KEY_NB_SAMPLES]
    elif sample_type == SampleTypes.inter:
        # every problem contributes exactly one initial-state sample, which
        # is excluded here
        available = stats[KEY_NB_SAMPLES] - stats[KEY_NB_PROBLEMS]
    elif sample_type == SampleTypes.init:
        return stats[KEY_NB_PROBLEMS]
    else:
        # explicit raise instead of `assert False`, which `python -O` strips
        # (the function would then silently return None)
        raise ValueError("Unknown sample type: %r" % (sample_type,))

    if samples_per_problem is None:
        return available
    return min(stats[KEY_NB_PROBLEMS] * samples_per_problem, available)


def translate(file_task, file_domain=None, file_fast_downward=None, build=None,
              translator_options=None):
    """Run the Fast Downward driver with ``--translate`` on a task.

    The command is
    ``<driver> [--build=<build>] --translate [<domain>] <task>
    [--translate-options <options...>]``.

    :param file_task: path of the task/problem file
    :param file_domain: optional domain file, placed before the task file
    :param file_fast_downward: driver script; defaults to ``fast-downward.py``
        two directory levels above this module
    :param build: optional build name passed as ``--build=<build>``
    :param translator_options: iterable of extra translator arguments (a
        single string is treated as one argument)
    :return: the exit code of the driver process (previously discarded)
    """
    if file_fast_downward is None:
        file_fast_downward = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "fast-downward.py")

    command = [file_fast_downward]
    if build is not None:
        command.append("--build=%s" % build)
    command.append("--translate")
    if file_domain is not None:
        command.append(file_domain)
    command.append(file_task)
    if translator_options is not None:
        if isinstance(translator_options, str):
            # avoid splitting a lone option string into characters
            translator_options = [translator_options]
        command.append("--translate-options")
        command.extend(translator_options)
    return subprocess.call(command)


def kill_process_and_children(pid):
    """Best-effort kill of process ``pid`` and all its descendants.

    Children (collected recursively) are killed first, then the process
    itself. Processes that already vanished or that we may not signal are
    skipped, so a single failure no longer prevents the others from being
    killed.

    :param pid: process id of the root process
    :return: None
    """
    tolerated = (psutil.AccessDenied, psutil.NoSuchProcess)
    try:
        proc = psutil.Process(pid)
    except tolerated:
        return
    try:
        children = proc.children(recursive=True)
    except tolerated:
        # still try to kill the root process below
        children = []
    for proc_child in children:
        try:
            proc_child.kill()
        except tolerated:
            pass
    try:
        proc.kill()
    except tolerated:
        pass


def run_fast_downward(cmd):
    """Run ``cmd`` and return its captured standard output.

    A non-zero exit status in the tolerated set below is not an error: the
    output produced before exiting is returned. Any other failure echoes
    that output to stderr and re-raises ``CalledProcessError``.

    :param cmd: argument list for the subprocess
    :return: captured stdout
    """
    # NOTE(review): presumably Fast Downward exit codes for expected
    # outcomes (unsolvable / resource limits); 247 is unexplained here --
    # confirm against the Fast Downward exit-code documentation.
    tolerated = (10, 11, 12, 20, 21, 22, 23, 24, 33, 247)
    try:
        return subprocess.check_output(cmd)
    except subprocess.CalledProcessError as err:
        if err.returncode in tolerated:
            return err.output
        print(err.output, file=sys.stderr)
        sys.stderr.flush()
        raise


class DefaultFormatDict(dict):
    """dict that resolves missing ``"name|default"`` keys.

    A missing key containing exactly one ``|`` is split into a name and a
    default (both stripped of surrounding whitespace). The stored value for
    the name is returned if present, otherwise the default string. Every
    other missing key, or one with an empty name, raises KeyError.

    NOTE: when used with ``string.Formatter``, characters such as '.', '[',
    ':' or '!' in a placeholder are interpreted by the format machinery
    before the key reaches this dict.
    """

    def __missing__(self, key):
        parts = key.split("|")
        if len(parts) != 2:
            raise KeyError(key)
        name, fallback = [part.strip() for part in parts]
        if not name:
            raise KeyError(key)
        return self[name] if name in self else fallback


def format_with_defaults(template, args, kwargs):
    """Format ``template`` like ``str.format``, with inline defaults.

    Placeholders of the form ``{name|default}`` fall back to ``default`` when
    ``name`` is not in ``kwargs`` (see DefaultFormatDict). ``kwargs`` is
    copied, not modified.

    :param template: format string
    :param args: positional arguments for ``{0}``-style fields
    :param kwargs: mapping of named arguments
    :return: the formatted string
    """
    lookup = DefaultFormatDict(kwargs)
    formatter = string.Formatter()
    return formatter.vformat(template, args, lookup)


if __name__ == "__main__":
    # This module only provides helpers and must not be run as a script.
    print("Thou shall not call me directly.")
    # sys.exit instead of `assert False`: asserts are stripped under
    # `python -O`, which would let the script exit with status 0.
    sys.exit(1)
