import enum
import os
import re
import platform
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))

from tools import parsing as apt

from . import rl_constants



class Ankers(enum.Enum):
    """Prefixes ("ankers") that mark the individual options inside a
    configuration string.

    Each member's value is the literal prefix a token has to start with;
    the getters in this module pass ``<member>.value`` to ``_get_anker`` to
    extract the text following the prefix.
    """
    BIAS = "bias_"
    COLD_START = "cold_"
    DROPOUT = "drp_"
    EVALUATOR = "evaluator_"
    EXPAND_GOAL = "expgoal_"
    L2 = "l2_"
    LEARNING_RATE_DECAY = "lr_decay_"
    LOOKAHEAD = "lookahead_"
    LOSS = "loss_"
    MUTEXES = "mutex_"
    NETWORK = "network_"
    OPTIMIZER = "opt_"
    PARTIAL = "partial_"
    REPETITIONS = "x_"
    REPLAY = "replay_"
    REPLAY_BUFFER_FACTOR = "buffer_"
    SAMPLING_TECHNIQUE = "samptech_"
    # The A_/S_/UP_/W_ prefixes below are looked up with split_char="+" by
    # their getters, i.e. inside "+"-separated strings instead of
    # "-"-separated ones.
    SAMPLING_TECHNIQUE_ACTIVATE = "A_"
    SAMPLING_TECHNIQUE_DISTRIBUTION = "samptechdist_"
    SAMPLING_TECHNIQUE_SCRAMBLES = "S_"
    SAMPLING_TECHNIQUE_UPGRADE = "UP_"
    SAMPLING_TECHNIQUE_WEIGHT = "W_"
    SCRAMBLES = "scramble_"
    SEARCH = "search_"
    INCLUDE = "INCLUDE_"
    TIME = "time_"



def _get_anker(s, anker, required=True, multiple=False, split_char="-"):
    if s.endswith(".py"):
        s = s[:-3]
    found = [x[len(anker):] for x in s.split(split_char)
             if x.startswith(anker)]
    assert not required or len(found) > 0, "%s: %s" % (anker, str(found))
    if multiple:
        return found
    else:
        assert len(found) <= 1, "%s: %s" % (anker, str(found))
        return None if len(found) == 0 else found[0]


def get_lookahead(s):
    """
    Build the lookahead arguments from the optional "lookahead_" anker.

    :param s: configuration string
    :return: [] if the anker is absent, else ["--lookahead", <value>]; the
             value has to be the canonical string form of an integer
    """
    depth = _get_anker(s, Ankers.LOOKAHEAD.value, required=False)
    if depth is not None:
        assert str(int(depth)) == depth
        return ["--lookahead", depth]
    return []


def get_search(s):
    """
    Build the search engine arguments from the optional "search_" anker.

    :param s: configuration string
    :return: [] if the anker is absent, else ["--search-engine", <value>]
    """
    engine = _get_anker(s, Ankers.SEARCH.value, required=False)
    if engine is None:
        return []
    return ["--search-engine", engine]


def get_sampling_engine(s):
    """
    Build the sampling engine arguments.

    The value is read from the same "search_" anker that get_search uses.

    :param s: configuration string
    :return: [] if the anker is absent, else ["--sampling-engine", <value>]
    """
    engine = _get_anker(s, Ankers.SEARCH.value, required=False)
    if engine is None:
        return []
    return ["--sampling-engine", engine]


def get_evaluator(s):
    """
    Return the value of the "evaluator_" anker.

    :param s: configuration string
    :return: None if the prefix does not occur anywhere in s, else the text
             following the prefix in its token
    """
    # Take the prefix from the enum like the other getters do instead of
    # repeating the literal.
    anker = Ankers.EVALUATOR.value
    if anker not in s:
        return None
    return _get_anker(s, anker)


def get_replay(s):
    """
    Return the value of the mandatory "replay_" anker.

    :param s: configuration string
    :return: text following the prefix
    """
    anker = Ankers.REPLAY.value
    return _get_anker(s, anker)


def get_time(s):
    """
    Read the mandatory "time_" anker and convert it with apt.time.

    :param s: configuration string
    :return: result of apt.time applied to the anker value
    """
    raw_time = _get_anker(s, Ankers.TIME.value)
    return apt.time(raw_time)


def get_partial(s):
    """
    Read the "partial_" anker.

    :param s: configuration string
    :return: 0 if the prefix does not occur in s, else the anker value
             converted by apt.int_positive
    """
    anker = Ankers.PARTIAL.value
    if anker not in s:
        return 0
    return apt.int_positive(_get_anker(s, anker))


def get_l2(s):
    """
    Read the "l2_" anker.

    :param s: configuration string
    :return: 0 if the prefix does not occur in s, else the anker value
             converted by apt.float_interval(min_value=0)
    """
    anker = Ankers.L2.value
    if anker not in s:
        return 0
    convert = apt.float_interval(min_value=0)
    return convert(_get_anker(s, anker))


def get_opt(s):
    """
    Translate the "opt_" anker into an optimizer definition.

    :param s: configuration string
    :return: "adam" if the prefix does not occur in s, else the definition
             registered for the given key (adam, sgd, sgdM); unknown keys
             fail the assertion
    """
    anker = Ankers.OPTIMIZER.value
    if anker not in s:
        opt = "adam"
    else:
        known = {
            "adam": "adam",
            "sgd": "sgd",
            "sgdM": "sgd(momentum=1.0)",
        }
        opt = known.get(_get_anker(s, anker))
    assert opt is not None, "Invalid optimizer: %s" % s
    return opt


def get_loss(s):
    """
    Read the "loss_" anker.

    :param s: configuration string
    :return: "mse" if the prefix does not occur in s, else the anker value
    """
    anker = Ankers.LOSS.value
    if anker not in s:
        return "mse"
    return _get_anker(s, anker)


def get_dropout(s):
    """
    Read the "drp_" anker.

    :param s: configuration string
    :return: 0 if the prefix does not occur in s, else the anker value
             converted by apt.float_interval(min_value=0, max_value=1)
    """
    anker = Ankers.DROPOUT.value
    if anker not in s:
        return 0
    convert = apt.float_interval(min_value=0, max_value=1)
    return convert(_get_anker(s, anker))


def get_lr_decay(s):
    """
    Translate the "lr_decay_" anker into a learning rate scheduler callback.

    :param s: configuration string
    :return: "" if the prefix does not occur in s, else the callback
             definition (with leading comma) for exp, cycle or cycle_exp;
             other values fail the assertion
    """
    anker = Ankers.LEARNING_RATE_DECAY.value
    if anker not in s:
        return ""

    schedulers = {
        "exp": ",keras_learning_rate_scheduler_exponential(decay_rate=0.99999,decay_step=1)",
        "cycle": ",keras_learning_rate_scheduler_cycle(min_lr=0.0001,max_lr=0.01,cycle_length=500)",
        "cycle_exp": ",keras_learning_rate_scheduler_cycle(min_lr=0.0001,max_lr=0.01,cycle_length=500,decay_rate=0.99999,decay_step=1)",
    }
    decay = _get_anker(s, anker)
    scheduler = schedulers.get(decay)
    assert scheduler is not None, decay
    return scheduler


def get_buffer_factor(s):
    """
    Build the replay buffer size arguments from the "buffer_" anker.

    :param s: configuration string
    :return: [] if the prefix does not occur in s, else
             ["--buffer-size-factor", <value>]
    """
    anker = Ankers.REPLAY_BUFFER_FACTOR.value
    if anker not in s:
        return []
    return ["--buffer-size-factor", _get_anker(s, anker)]

def get_bias(s, anker=Ankers.BIAS.value, split_char="-"):
    """
    Parse the optional bias anker. Expected format of its value:
    {ff|nn|lmcut|ipdb}_{probabilistic|prob|max}_{RELOAD_FREQUENCY:int}[_{adapt_bias:float}]
    nn assumes predefinition of networks as hnn
    :param s: configuration string
    :param anker: prefix of the bias token
    :param split_char: separator between the tokens of s
    :return: [] if the anker is absent, else
             [bias, probabilistic, reloads, adapt_bias]; reloads is forced
             to -1 for every heuristic except nn and adapt_bias defaults
             to -1.0
    """
    raw = _get_anker(s, anker, required=False, split_char=split_char)
    if raw is None:
        return []

    fields = raw.split("_")
    assert len(fields) in [3, 4]
    heuristic, mode = fields[0], fields[1]

    if mode in ("prob", "probabilistic"):
        probabilistic = "true"
    elif mode == "max":
        probabilistic = "false"
    else:
        assert False, mode

    reloads = int(fields[2])

    # Heuristics for which the given reload frequency is ignored.
    without_reload = {
        "ff": "hff()",
        "lmcut": "hlmcut()",
        "ipdb": "hipdb(max_time=120s)",
    }
    if heuristic == "nn":
        bias = "hnh(hnn)"
    elif heuristic in without_reload:
        bias = without_reload[heuristic]
        reloads = -1
    else:
        assert False, heuristic

    adapt_bias = float(fields[3]) if len(fields) == 4 else -1.0

    return [bias, probabilistic, reloads, adapt_bias]


def get_sampling_technique_distribustion(s):
    """
    Parse the optional "samptechdist_" anker, a "_"-separated list of
    non-negative numbers.

    :param s: configuration string
    :return: [] if the anker is absent, else the numbers converted to float
             and back to str
    """
    raw = _get_anker(s, Ankers.SAMPLING_TECHNIQUE_DISTRIBUTION.value,
                     required=False, multiple=False)
    if raw is None:
        return []

    weights = []
    for part in raw.split("_"):
        part = part.strip()
        if part != "":
            weights.append(float(part))
    assert all(weight >= 0 for weight in weights)
    return [str(weight) for weight in weights]


def get_sampling_technique_weight(s):
    """
    Read the optional "W_" anker from a "+"-separated string.

    :param s: "+"-separated option string
    :return: None if the anker is absent, else its value, which has to be
             the canonical string form of an integer
    """
    weight = _get_anker(s, Ankers.SAMPLING_TECHNIQUE_WEIGHT.value,
                        required=False, multiple=False, split_char="+")
    if weight is not None:
        assert str(int(weight)) == weight
    return weight


def get_sampling_technique_upgrade(s):
    """
    Read the optional "UP_" anker from a "+"-separated string.

    :param s: "+"-separated option string
    :return: None if the anker is absent, else the "_"-separated parts of
             its value with every "~" replaced by "="
    """
    upgrade = _get_anker(s, Ankers.SAMPLING_TECHNIQUE_UPGRADE.value,
                         required=False, multiple=False, split_char="+")
    if upgrade is None:
        return None

    parts = []
    for part in upgrade.split("_"):
        parts.append(part.replace("~", "="))
    return parts


def _get_condition(s):
    map_condition = {
        "after": "min_time",
        "before": "max_time",
        "every": "every_x_time"
    }
    map_unit = {
        "s": 1,
        "m": 60,
        "h": 3600,
    }
    m = re.match(r"(after|before|every)(\d+)([smh])", s)
    assert m is not None, s
    return "keras_condition_counter(%s=%i)" % (
        map_condition[m.group(1)], int(m.group(2)) * map_unit[m.group(3)])


def get_sampling_technique_scrambles(s):
    """
    Read the optional "S_" anker from a "+"-separated string.

    The value is either SCRAMBLES or SCRAMBLES_CONDITION_MODIFIER where
    SCRAMBLES is an integer, CONDITION is understood by _get_condition and
    MODIFIER is a number optionally preceded by one of + - * /.

    :param s: "+"-separated option string
    :return: None if the anker is absent, else
             ["++max-scrambles", SCRAMBLES] optionally followed by
             ["++increase-scrambling", <condition>, MODIFIER]
    """
    anker = Ankers.SAMPLING_TECHNIQUE_SCRAMBLES.value
    scrambles = _get_anker(s, anker, required=False, multiple=False, split_char="+")
    if scrambles is None:
        return None

    scrambles = [x.strip() for x in scrambles.split("_")]
    assert all(x != "" for x in scrambles)
    assert len(scrambles) in [1, 3], \
        "Either scramble count or SCRAMBLE_CONDITON_MODIFIER"
    assert str(int(scrambles[0])) == scrambles[0]

    arguments = ["++max-scrambles", scrambles[0]]
    if len(scrambles) == 3:
        # The whole modifier has to match: the previous pattern had an
        # unescaped "." and no end anchor, so anything starting with a
        # digit (e.g. "2abc") was accepted.
        assert re.fullmatch(r"[+\-*/]?\d+(\.\d+)?", scrambles[2]), scrambles[2]
        arguments += ["++increase-scrambling",
                      _get_condition(scrambles[1]),
                      scrambles[2]]
    return arguments


def get_sampling_technique_active(s):
    """
    Read the optional "A_" anker from a "+"-separated string.

    :param s: "+"-separated option string
    :return: None if the anker is absent, else its value converted by
             _get_condition
    """
    active = _get_anker(s, Ankers.SAMPLING_TECHNIQUE_ACTIVATE.value,
                        required=False, multiple=False, split_char="+")
    if active is not None:
        return _get_condition(active)
    return None


# Maps the abbreviations accepted in the "samptech_" anker by
# get_sampling_technique_prior_2020_04_09 to the names that function returns.
SAMPLING_TECHNIQUE_ABBREVIATIONS = {
        "iforward": "iforward_none",
        "uniform": "uniform_none",
        "gbackward": "gbackward_none",
        "gbackwardnb": "gbackward_none_nb"
    }


def get_sampling_technique_prior_2020_04_09(s):
    """
    Parse the optional "samptech_" anker as a "_"-separated list of
    abbreviations from SAMPLING_TECHNIQUE_ABBREVIATIONS.

    :param s: configuration string
    :return: ["gbackward"] if the anker is absent, else the expanded names
             of the given abbreviations (at least one is required)
    """
    raw = _get_anker(s, Ankers.SAMPLING_TECHNIQUE.value,
                     required=False, multiple=False)
    if raw is None:
        return ["gbackward"]

    stripped = (part.strip() for part in raw.split("_"))
    keys = [key for key in stripped if key != ""]
    assert all(key in SAMPLING_TECHNIQUE_ABBREVIATIONS for key in keys), keys
    assert len(keys) > 0
    return [SAMPLING_TECHNIQUE_ABBREVIATIONS[key] for key in keys]


def get_sampling_technique_since_2020_04_09(s):
    """
    Collect the values of all "samptech_" ankers.

    :param s: configuration string
    :return: list with one entry per occurrence of the anker (possibly
             empty)
    """
    anker = Ankers.SAMPLING_TECHNIQUE.value
    return _get_anker(s, anker, required=False, multiple=True)


def get_expand_goal(s):
    """
    Parse the optional "expgoal_" anker with the format
    {int}_{false|true}[_{int}].

    :param s: configuration string
    :return: [] if the anker is absent, else ["--expand-goal"] followed by
             the three fields; a missing third field defaults to "-1"
    """
    raw = _get_anker(s, Ankers.EXPAND_GOAL.value,
                     required=False, multiple=False)
    if raw is None:
        return []

    fields = [field.strip() for field in raw.split("_")]
    assert len(fields) in [2, 3]
    assert str(int(fields[0])) == fields[0]
    assert fields[1] in ["false", "true"]
    if len(fields) == 2:
        fields.append("-1")
    assert str(int(fields[2])) == fields[2]
    return ["--expand-goal"] + fields


def get_mutex_options_arguments(args):
    """
    Translate a mutex setting into command line arguments.

    :param args: one of None, "none", "trans", "translator", "h2"
    :return: [] for None/"trans"/"translator", translator options for
             "none" and the preprocessor given by
             rl_constants.PATH_H2_PREPROCESSOR for "h2"
    """
    assert args in [None, "none", "trans", "translator", "h2"]
    if args == "none":
        return ["--translator-options",
                " --invariant-generation-max-candidates 0"]
    if args == "h2":
        return ["--preprocessor", rl_constants.PATH_H2_PREPROCESSOR]
    if args is None or args in ("trans", "translator"):
        return []
    assert False


def get_mutex_options(s):
    """
    Read the optional "mutex_" anker and translate it via
    get_mutex_options_arguments.

    :param s: configuration string
    :return: list of command line arguments
    """
    setting = _get_anker(s, Ankers.MUTEXES.value,
                         required=False, multiple=False)
    return get_mutex_options_arguments(setting)


def get_training_repetitions(s):
    """
    Read the optional "x_" anker.

    :param s: configuration string
    :return: 1 if the anker is absent, else its value as int
    """
    repetitions = _get_anker(s, Ankers.REPETITIONS.value,
                             required=False, multiple=False)
    if repetitions is None:
        return 1
    return int(repetitions)


# Template of the network definition. get_network_attribute fills in
# hidden_layers, save_condition and lr_decay and leaves the remaining
# placeholders ({loss}, {replay}, {optimizer}, {l2}, {dropout}) in place for
# a later format call.
TEMPLATE_NETWORK = \
    "keras_mlp(tparams=ktparams(epochs=99999999,loss={loss},epoch_verbosity=50," \
    "rl_batch_generator={replay},batch=250,callbacks=" \
    "[keras_model_saver({save_condition},step=train_end,flags=[model_saved])" \
    "{lr_decay}]," \
    "optimizer={optimizer})," \
    "y_fields=0,x_fields=1,load=model,store=model,learner_formats=[pb,h5]" \
    ",hidden_layer_size=[{hidden_layers}]," \
    "output_units=-1,batch_normalization=1,l2={l2},dropout={dropout})"


def get_network_attribute(network_size,save_condition="threshold=0.1,every_x_epochs=50", lr_decay=""):
    """
    Build the network definition for a network size.

    network_size is either a named size (tiny, small, medium, each
    optionally followed by NoRB) or n{N}rb{RB}[d{D}] with N neurons per
    layer, RB residual blocks and D dense layers (default 2). The NoRB
    suffix replaces every residual block by two additional dense layers.

    :param network_size: size description as explained above
    :param save_condition: inserted as arguments of keras_model_saver
    :param lr_decay: inserted after the keras_model_saver callback
    :return: TEMPLATE_NETWORK with hidden_layers, save_condition and
             lr_decay filled in; the other placeholders are kept verbatim
    """
    d_default = 2
    # Raw string: "\d" in a normal string literal is an invalid escape
    # sequence.
    m = re.match(r"(((?:tiny|small|medium)(?:NoRB)?)|n(\d+)rb(\d+)(?:d(\d+))?)",
                 network_size)
    assert m is not None, network_size

    named_size = m.group(2)
    if named_size is not None:
        if network_size.startswith("tiny"):
            n, d, rb = 64, d_default, 0
        elif network_size.startswith("small"):
            n, d, rb = 250, d_default, 1
        elif network_size.startswith("medium"):
            n, d, rb = 250, d_default, 2
        else:
            assert False, "Internal Error"

        if network_size.endswith("NoRB"):
            d += 2 * rb
            rb = 0
    else:
        n = int(m.group(3))
        rb = int(m.group(4))
        d = d_default if m.group(5) is None else int(m.group(5))
        assert n > 0 and rb >= 0 and d >= 0, "n: {n}, rb: {rb}".format(n=n, rb=rb)

    # Format with explicit arguments. The previous format(**locals()) call
    # sat inside a list comprehension, whose scope does not contain `n`
    # before Python 3.12, so every network with residual blocks raised a
    # KeyError there.
    dense_layers = [str(n)] * d
    residual_blocks = [
        "keras_residual_block(hidden_layer_count=2,"
        "hidden_layer_size={n})".format(n=n)
    ] * rb
    hidden_layers = ",".join(dense_layers + residual_blocks)

    return TEMPLATE_NETWORK.format(
        hidden_layers=hidden_layers,
        replay="{replay}",
        l2="{l2}",
        dropout="{dropout}",
        optimizer="{optimizer}",
        save_condition=save_condition,
        lr_decay=lr_decay,
        loss="{loss}"
    )


def get_network(s, save_condition="threshold=0.1,every_x_epochs=50", lr_decay=""):
    """
    Build the network definition for the mandatory "network_" anker.

    :param s: configuration string
    :param save_condition: forwarded to get_network_attribute
    :param lr_decay: forwarded to get_network_attribute
    :return: result of get_network_attribute for the anker value
    """
    size = _get_anker(s, Ankers.NETWORK.value)
    return get_network_attribute(
        size, save_condition=save_condition, lr_decay=lr_decay)



def natural_sort(l):
    """
    Sort strings case-insensitively with digit runs compared as numbers.

    :param l: iterable of strings
    :return: new sorted list
    """
    def _key(item):
        # Splitting at a captured group keeps the digit runs in the result.
        chunks = re.split('([0-9]+)', item)
        return [int(chunk) if chunk.isdigit() else chunk.lower()
                for chunk in chunks]

    return sorted(l, key=_key)


def get_domain_suite(domain):
    """
    List the tasks of a domain directory below rl_constants.BENCHMARK_REPO.

    :param domain: name of the domain directory
    :return: naturally sorted "<domain>:<file>" entries for every file that
             ends in ".pddl" and does not contain "domain" in its name
    """
    dir_domain = os.path.join(rl_constants.BENCHMARK_REPO, domain)
    suite = []
    for file_name in natural_sort(os.listdir(dir_domain)):
        if file_name.endswith(".pddl") and "domain" not in file_name:
            suite.append("%s:%s" % (domain, file_name))
    return suite


def get_tasks(arg):
    """
    Resolve a task description to a list of "<domain>:<file>" tasks.

    Supported descriptions:
    - "storage": the suite of that domain without its first 15 entries
    - name of a directory below rl_constants.BENCHMARK_REPO: its suite
    - {ecai|further1}_[wo_][DOMAIN_][min|maxN]: tasks taken from
      rl_constants.TASK_SELECTIONS. DOMAIN restricts the selection to one
      domain ("+" stands for "-"), wo_ excludes that domain instead. For a
      domain with marked tasks, "min" yields all present tasks starting at
      the first marked one, "maxN" the last N marked tasks, and no suffix
      all marked tasks. Domains whose selection is None contribute all
      present tasks.

    :param arg: task description
    :return: list of tasks (additionally printed for selections)
    """
    if arg == "storage":
        return get_domain_suite(arg)[15:]
    if os.path.isdir(os.path.join(rl_constants.BENCHMARK_REPO, arg)):
        return get_domain_suite(arg)
    elif arg.startswith("ecai") or arg.startswith("further1"):
        m = re.match(r"^(ecai|further1)_(wo_)?([a-zA-Z0-9+]+_)?((min)|(max)(\d+))?$", arg)
        assert m is not None, arg

        tasks = []
        task_selection = m.group(1)
        assert task_selection in rl_constants.TASK_SELECTIONS
        task_selection = rl_constants.TASK_SELECTIONS[task_selection]
        flag_without = m.group(2) is not None
        flag_min_ecai = m.group(5) is not None
        flag_max = int(m.group(7)) if m.group(7) is not None else None
        # group(3) includes the trailing "_" separating it from the suffix
        chosen_domain = None if m.group(3) is None else m.group(3)[:-1]
        if chosen_domain is not None:
            chosen_domain = chosen_domain.replace("+", "-")
        assert flag_max is None or not flag_min_ecai
        assert chosen_domain is None or chosen_domain in task_selection

        if flag_without:
            domains = [d for d in task_selection.keys() if d != chosen_domain]
        elif chosen_domain is None:
            domains = task_selection.keys()
        else:
            domains = [chosen_domain]

        for domain in domains:
            marked_tasks = task_selection[domain]
            present_tasks = natural_sort(get_domain_suite(domain))
            if marked_tasks is None:
                tasks.extend(present_tasks)
            else:
                marked_tasks = ["%s:%s" % (domain, et) for et in marked_tasks]
                assert all(et in present_tasks for et in marked_tasks), domain
                if flag_min_ecai:
                    for no_pt, pt in enumerate(present_tasks):
                        if pt in marked_tasks:
                            tasks.extend(present_tasks[no_pt:])
                            break
                elif flag_max is None:
                    tasks.extend(marked_tasks)
                elif flag_max > 0:
                    tasks.extend(marked_tasks[-flag_max:])
                # flag_max == 0 selects nothing; slicing with [-0:] would
                # have returned every marked task.
        print(tasks)
        return tasks
    else:
        assert False, arg


def get_scrambling(s):
    """
    Parse the optional "scramble_" anker, a "_"-separated list of integers.

    :param s: configuration string
    :return: [] if the anker is absent, else the non-empty entries as
             strings; each has to be the canonical string form of an integer
    """
    raw = _get_anker(s, Ankers.SCRAMBLES.value,
                     required=False, multiple=False)
    if raw is None:
        return []

    values = []
    for part in raw.split("_"):
        part = part.strip()
        if part != "":
            values.append(part)
    assert all("inc" not in value for value in values)
    assert all(str(int(value)) == value for value in values)
    return values

    #scrambling = scrambling.split("inc")
    #assert len(scrambling) == 1 and s.find("inc") == -1 or \
    #    len(scrambling) == 2 and s.find("inc") > -1
    # return (["--max-scrambles", scrambling[0]] +
    #         ([] if len(scrambling) == 1 else
    #          ["--increase-scrambling",
    #           "keras_condition_counter(flags=inc_scrambling)",
    #           "1", scrambling[1]]))


def get_cold(s):
    """
    Build the cold start arguments from the optional "cold_" anker with the
    format {evaluator}_{seconds}.

    :param s: configuration string
    :return: [] if the anker is absent, else
             ["--cold-start-evaluator", <evaluator>,
              "keras_condition_counter(min_time=<seconds>)",
              "0.1", "1.0", "0.0"]
    """
    # required=False: with the default (required=True) a missing anker
    # failed an assertion inside _get_anker and the `is None` branch below
    # could never be reached.
    cold_start = _get_anker(s, Ankers.COLD_START.value, required=False)
    if cold_start is None:
        return []

    m = re.match(r"([a-zA-Z0-9]+)_(\d+)", cold_start)
    assert m, "format shall be evaluator_seconds. given: |%s|" % cold_start
    return ["--cold-start-evaluator", m.group(1),
            "keras_condition_counter(min_time={})".format(m.group(2)),
            "0.1", "1.0", "0.0"]

