from __future__ import print_function

from tools.argparse_types import *

from src.training import parser, parser_tools
from src.training.learners import Learner
from src.training.learners.keras_networks import KerasNetwork
from src.training.learners.keras_networks.keras_callbacks import \
    BaseKerasCallback
from src.training.misc import StreamDefinition

import numpy as np
import os
import random
import re
import sys
# Python 2/3 compatibility: on Python 2 the `subprocess32` package is imported
# under the name `subprocess`; `decoder` has a version specific body.
if sys.version_info < (3,):
    import subprocess32 as subprocess


    def decoder(s):
        # Python 2: returns s.decode() (default codec).
        return s.decode()
else:
    import subprocess


    def decoder(s):
        # Python 3: returns the input unchanged.
        # NOTE(review): this does not decode `bytes` on Python 3; presumably
        # callers only pass `str` here — confirm.
        return s

# Regexes for slurm
# Used with finditer on the `squeue -o "%F_[%K]"` output in
# format_slurm_dependency; group(1) is the leading job id, group(2) the
# optional "_[...]" suffix.
REGEX_JOB_LIST = re.compile(r"(\d+)(_\[([^\]])+\])?")
# Keys accepted inside a "{key:value;key:value}" dependency placeholder.
PATTERN_DEPENDENCY_KEYS = r"user|u|partition|p"
# A value runs until the next ";" or the closing "}".
PATTERN_DEPENDENCY_VALUE = r"[^};]+"

# Matches placeholders such as "{u:alice;p:gpu}"; group(1) is the text between
# the braces, i.e. the ";" separated list of key:value pairs.
REGEX_DEPENDENCY_SLURM = re.compile(r"{((%s):(%s)(;(%s):(%s))*)}" % (
    PATTERN_DEPENDENCY_KEYS, PATTERN_DEPENDENCY_VALUE,
    PATTERN_DEPENDENCY_KEYS, PATTERN_DEPENDENCY_VALUE))

# Regexes for StreamDefinitions
# "{FirstProblem}", "{FirstProblem:N}" or "{FirstProblem:-N}";
# group(2) is the optional "-", group(3) the optional N.
REGEX_STREAM_FIRST_PROBLEM = re.compile(r"{FirstProblem(:(-)?(\d+))?}")
# "{TMPDIR}" placeholder for the temporary folder.
REGEX_STREAM_TEMPORARY_FOLDER = re.compile(r"{TMPDIR}")


class SAS(object):
    """
    Lazily parsed view of a SAS+ task file.

    Only the ``begin_variable`` and ``begin_mutex_group`` sections are
    parsed. They yield the domain size of every variable and the mutex
    groups, both of which are parsed on first access and cached.

    Two state encodings are supported:
      * SAS encoding: one int per variable holding its value (-1 = unset).
      * STRIPS encoding: one entry per (variable, value) fact; the facts of
        variable ``v`` occupy ``domain_intervals[v]:domain_intervals[v + 1]``.
    """
    REGEX_VARIABLE = re.compile(r"""\s*begin_variable
\s*[a-zA-Z0-9]+
\s*-?\d+
\s*(\d+)
\s*((<none of those>|(Negated)?Atom [^\n]+)\n)+\s*end_variable""", re.MULTILINE)

    REGEX_MUTEX = re.compile(r"""\s*begin_mutex_group
\s*(\d+)
\s*((\d+\s+\d+\n)+)\s*end_mutex_group""")

    def __init__(self, file):
        """
        :param file: path of the SAS file. Its complete content is read
                     immediately; parsing happens lazily.
        """
        self._file = file
        assert os.path.isfile(file)
        with open(self._file, "r") as f:
            self._content = f.read()

        # Lazily computed caches, see the corresponding properties.
        self._domains = None
        self._domain_intervals = None
        self._nb_strips_atoms = None

        # Mutex groups: tuple of groups, each a tuple of (var, val) tuples,
        # i.e. (((var1, val1), (var2, val2), ...), ...)
        self._mutexes = None
        # _mutex_relations[var][val]: set of (var', val') facts sharing a
        # mutex group with (var, val). Built together with _mutexes.
        self._mutex_relations = None

    @property
    def domains(self):
        """List with the domain size of every variable, in file order."""
        if self._domains is None:
            self._domains = [int(x[0]) for x in
                             self.REGEX_VARIABLE.findall(self._content)]
            assert all(x > 0 for x in self._domains)
        return self._domains

    @property
    def mutexes(self):
        """
        Mutex groups as a tuple of tuples of (var, val) int pairs.

        Computing them also builds the lookup table used by ``are_mutex``.
        """
        if self._mutexes is None:
            groups = []
            for _size, body, _last in self.REGEX_MUTEX.findall(self._content):
                # REGEX_MUTEX separates the two numbers of a fact by \s+
                # (tabs, several spaces, ...), so split on any whitespace and
                # pair consecutive numbers instead of splitting on " ".
                numbers = [int(token) for token in body.split()]
                groups.append(tuple(zip(numbers[0::2], numbers[1::2])))
            self._mutexes = tuple(groups)

            self._mutex_relations = [[set() for _val in range(size)]
                                     for size in self.domains]
            for group in self._mutexes:
                for fact in group:
                    for other in group:
                        # Every ordered pair is visited, so the symmetric
                        # relation gets filled in both directions.
                        if fact != other:
                            self._mutex_relations[fact[0]][fact[1]].add(other)
        return self._mutexes

    def are_mutex(self, var_val1, var_val2):
        """
        True iff the facts ``var_val1`` and ``var_val2`` (each a (var, val)
        tuple) occur together in some mutex group.
        """
        self.mutexes  # ensures _mutex_relations has been built
        return var_val2 in self._mutex_relations[var_val1[0]][var_val1[1]]

    def has_mutex(self, var_val, state):
        """
        True iff the fact ``var_val`` is mutex with a value assigned in the
        SAS encoded ``state``. The variable of ``var_val`` itself and unset
        variables (-1) are skipped.
        """
        for var2, val2 in enumerate(state):
            if var_val[0] == var2 or val2 == -1:
                continue
            if self.are_mutex(var_val, (var2, val2)):
                return True
        return False

    @property
    def domain_intervals(self):
        """
        Prefix sums of the domain sizes: [0, d0, d0 + d1, ...]. Entry v is
        the STRIPS index of the first fact of variable v.
        """
        if self._domain_intervals is None:
            self._domain_intervals = [0]
            for x in self.domains:
                self._domain_intervals.append(self._domain_intervals[-1] + x)
        return self._domain_intervals

    @property
    def nb_strips_atoms(self):
        """Total number of facts, i.e. the sum of all domain sizes."""
        if self._nb_strips_atoms is None:
            self._nb_strips_atoms = sum(self.domains)
        return self._nb_strips_atoms

    def complete_sas_states(self, states):
        """
        Assigns a value to every unset (-1) variable of every state such that
        each new value is not mutex with the values already in the state.

        The unset variables of a state are filled in random order with random
        admissible values. If some variable has no admissible value left,
        all of them are reset to -1 and the attempt is repeated. This loops
        forever for a state that has no admissible completion.

        :param states: 2D int numpy array, one SAS encoded state per row
                       (assumed numpy; the rows are modified in place)
        :return: ``states``
        """
        unset_variables = states == -1
        for state, unset_vars in zip(states, unset_variables):
            unset_vars = np.where(unset_vars == 1)[0]
            while True:  # Attempt to complete state
                np.random.shuffle(unset_vars)
                invalid = False
                for v in unset_vars:
                    candidate_values = [x for x in range(self.domains[v])
                                        if not self.has_mutex((v, x), state)]
                    if len(candidate_values) == 0:
                        invalid = True
                        break
                    else:
                        state[v] = np.random.choice(candidate_values)
                if invalid:
                    # Undo the partial assignment before retrying.
                    for v in unset_vars:
                        state[v] = -1
                else:
                    break
        return states

    def random_state(self):
        """
        Rejection-samples a complete SAS state (list of ints) in which at
        most one fact of every mutex group holds.
        """
        while True:
            state = [random.randint(0, domain_size - 1)
                     for domain_size in self.domains]
            if all(sum(state[var] == value for var, value in m) <= 1
                   for m in self.mutexes):
                return state

    def convert_strips_encodings_to_sas(self, states):
        """
        Converts STRIPS encoded states to the SAS encoding.

        :param states: numpy array of shape (nb_strips_atoms,) or
                       (n, nb_strips_atoms). A variable takes the offset of
                       a non-zero fact entry within its slice, or -1 if all
                       its entries are zero.
        :return: int array of shape (len(domains),) or (n, len(domains))
        """
        one_dimensional = len(states.shape) == 1
        if one_dimensional:
            states = np.expand_dims(states, 0)
        new_states = np.full((len(states), len(self.domains)), -1, dtype=int)
        for v in range(len(self.domains)):
            idx_state, val = np.where(states[:, self.domain_intervals[v]:
                                                self.domain_intervals[v + 1]])
            new_states[idx_state, v] = val
        return new_states[0] if one_dimensional else new_states

    def convert_sas_encodings_to_strip(self, states):
        """
        Converts SAS encoded states to the STRIPS encoding.

        :param states: int numpy array of shape (len(domains),) or
                       (n, len(domains)); -1 entries produce no fact.
        :return: float array of 0/1 with shape (nb_strips_atoms,) or
                 (n, nb_strips_atoms)
        """
        one_dimensional = len(states.shape) == 1
        if one_dimensional:
            states = np.expand_dims(states, 0)
        new_states = np.zeros(shape=(len(states), self.nb_strips_atoms))
        for v in range(len(self.domains)):
            idx_assigned = np.where(states[:, v] != -1)[0]
            new_states[idx_assigned,
                       states[idx_assigned, v] + self.domain_intervals[v]] = 1
        return new_states[0] if one_dimensional else new_states


# Patterns matching known library noise lines.
# NOTE(review): not referenced in this part of the file; presumably used
# elsewhere to strip these lines from captured output — confirm.
# TensorFlow's timestamped "Your CPU supports instructions that this
# TensorFlow binary was not compiled to use: ..." info line.
PATTERN_TENSORFLOW_WARNING = re.compile(
    r"\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d\.\d+: "
    r"I tensorflow/core/platform/cpu_feature_guard.cc:\d+] Your CPU supports"
    r" instructions that this TensorFlow binary was not compiled to use: "
    r"[^\n]+\n")
# The "Using TensorFlow backend." line.
PATTERN_USING_TENSORFLOW = re.compile(r"Using TensorFlow backend.\n")
# Two-line FutureWarning about issubdtype emitted from line 36 of some
# package's __init__.py while running "from ._conv import register_converters".
PATTERN_HDF5_CONVERSION_WARNING = re.compile(
    r"/\S+/__init__\.py:"
    r"36: FutureWarning: Conversion of the second argument of issubdtype from "
    r"`float` to `np\.floating` is deprecated\. In future, it will be treated as "
    r"`np\.float64 == np\.dtype\(float\)\.type`\.\n  from \._conv import register_"
    r"converters as _register_converters\n", re.MULTILINE)





def regex_format_string(s, regex_formatter_pairs):
    """
    Rewrites ``s`` by applying a sequence of regex based substitutions.

    The pairs are applied in order. Each pattern is matched against the
    result of the previous substitutions, and every non-overlapping match is
    replaced by ``modifier(match)``. The returned text is inserted literally
    (no backslash or group expansion).

    :param s: string to format
    :param regex_formatter_pairs: iterable of (compiled pattern,
                                  callable(match) -> str) pairs
    :return: formatted string
    """
    for pattern, modifier in regex_formatter_pairs:
        # re.sub with a callable replaces the manual "splice and track the
        # offset shift" loop and builds the result in a single pass instead
        # of re-slicing the whole string for every match.
        s = pattern.sub(modifier, s)
    return s


def format_stream_definitions(definition, problems, temporary_folder):
    """
    Expands the placeholders of a stream definition string.

    Supported placeholders:
      * ``{FirstProblem}``    -> ``problems[0]``
      * ``{FirstProblem:N}``  -> the first N characters of ``problems[0]``
      * ``{FirstProblem:-N}`` -> ``problems[0]`` without its last N characters
      * ``{TMPDIR}``          -> ``temporary_folder``

    :param definition: stream definition string to format
    :param problems: sequence of strings; only the first one is used
    :param temporary_folder: replacement for ``{TMPDIR}``; may be None if the
                             placeholder does not occur
    :return: the formatted definition
    """
    # NOTE(review): raise_error is not defined in this file section;
    # presumably it raises when its first argument is true — confirm.
    def modify_first_problem(match):
        raise_error(
            len(problems) == 0,
            "No problems defined for formatting the stream definition")
        first_problem = problems[0]
        # group(3) is the optional length N, group(2) the optional "-".
        if match.group(3) is not None:
            length = int(match.group(3))
            if match.group(2) is None:
                first_problem = first_problem[:length]
            elif length > 0:
                # Guarded because [:-0] would be [:0], i.e. the empty string.
                first_problem = first_problem[:-length]
        return first_problem

    def modify_tmp_dir(_):
        raise_error(
            temporary_folder is None,
            "No temporary directory given for formatting the stream definition")
        return temporary_folder

    return regex_format_string(
        definition,
        [(REGEX_STREAM_FIRST_PROBLEM, modify_first_problem),
         (REGEX_STREAM_TEMPORARY_FOLDER, modify_tmp_dir)])


def format_slurm_dependency(match):
    """
    Replaces a ``{key:value;...}`` slurm dependency placeholder by the ids of
    the jobs that ``squeue`` lists for the given filters.

    Supported keys are ``u``/``user`` (passed as ``--user``) and
    ``p``/``partition`` (passed as ``--partition``). Each may appear at most
    once.

    :param match: match object of REGEX_DEPENDENCY_SLURM
    :return: ":" separated job ids in ascending numeric order, or "" if
             squeue lists no job
    """
    has_user = False
    has_partition = False
    cmd = ["squeue", "--noheader", "-o", "%F_[%K]"]

    for pair in match.group(1).split(";"):
        key, value = pair.split(":", 1)
        if key in ["u", "user"]:
            assert not has_user
            cmd.extend(["--user", value])
            has_user = True
        if key in ["p", "partition"]:
            assert not has_partition
            cmd.extend(["--partition", value])
            has_partition = True

    # universal_newlines=True makes check_output return text. On Python 3 it
    # would otherwise return bytes, which the str regex below cannot search.
    output = subprocess.check_output(cmd, universal_newlines=True)
    # A distinct loop name avoids shadowing the `match` parameter (Python 2
    # list comprehensions leak their loop variable).
    job_ids = set(job.group(1) for job in REGEX_JOB_LIST.finditer(output))
    # Sorting makes the produced dependency string deterministic.
    return ":".join(sorted(job_ids, key=int))


def extract_and_remove_arguments(argv, activate_keys, deactivate_keys):
    """
    Separates the arguments belonging to selected parameters from a command
    line style list of strings. The input is not modified.

    An element in ``activate_keys`` starts a new extracted group. The group
    collects all following elements until the next activate or deactivate
    key. Elements in ``deactivate_keys`` end the current group and stay in
    the remaining list, like every element outside a group.

    :param argv: iterable of strings
    :param activate_keys: keys whose arguments are extracted
    :param deactivate_keys: keys that stop an extraction (they are kept)
    :return: (remaining elements as list,
              [[activate key, arg1, arg2, ...], [...], ...])
    """
    remaining = []
    groups = []
    current = None
    for token in argv:
        if token in activate_keys:
            if current is not None:
                groups.append(current)
            current = [token]
            continue
        if token in deactivate_keys:
            if current is not None:
                groups.append(current)
            current = None
            remaining.append(token)
            continue
        if current is None:
            remaining.append(token)
        else:
            current.append(token)

    if current is not None:
        groups.append(current)
    return remaining, groups


def split_on_base_level(to_split, delimiter=":", inc=None, dec=None):
    """
    Splits a string at the given delimiter wherever the delimiter is not
    nested inside a block. Empty segments are dropped. E.g.:
    hello:{world:how}:are:{you} -> ['hello', '{world:how}', 'are', '{you}']
    {hello:{world:how}:are:{you}} -> ['{hello:{world:how}:are:{you}}']

    :param to_split: string to split
    :param delimiter: delimiter at which to split
    :param inc: iterable of chars that increase the nesting (default ["{"])
    :param dec: iterable of chars that decrease the nesting (default ["}"])
    :return: list of segments
    """
    openers = ["{"] if inc is None else inc
    closers = ["}"] if dec is None else dec
    assert delimiter not in openers
    assert delimiter not in closers
    assert not (set(openers) & set(closers))

    parts = []
    current = []
    depth = 0
    for char in to_split:
        if char in openers:
            depth += 1
        elif char in closers:
            depth -= 1

        if depth == 0 and char == delimiter:
            if current:
                parts.append("".join(current))
            current = []
        else:
            current.append(char)

    assert depth == 0
    if current:
        parts.append("".join(current))
    return parts


def to_cpp_bool(var):
    """Renders a Python bool as the C++ literal "true" or "false"."""
    # bool cannot be subclassed, so this accepts exactly True and False.
    assert isinstance(var, bool)
    if var:
        return "true"
    return "false"


def sas(arg):
    """
    argparse type: checks ``arg`` with ``isfile``, requires a ".sas" suffix
    and wraps the path into a SAS object.
    """
    # NOTE(review): isfile is not defined in this file section; presumably it
    # validates that the path exists and returns it — confirm.
    path = isfile(arg)
    assert path.endswith(".sas")
    return SAS(path)



def slurm_dependency(arg):
    """
    argparse type: replaces every "{key:value;...}" placeholder in ``arg``
    (see REGEX_DEPENDENCY_SLURM) with the job ids found by
    format_slurm_dependency.
    """
    substitutions = [(REGEX_DEPENDENCY_SLURM, format_slurm_dependency)]
    return regex_format_string(arg, substitutions)


# Transformation names accepted by transform_numeric, mapped to their callable.
# "none" maps to None, presumably meaning "leave values unchanged" — confirm
# with the consumers of transform_numeric.
NUMERIC_TRANSFORMATIONS = {
    "none": None,
    "ln": np.log,  # natural logarithm
}


def transform_numeric(arg):
    """
    argparse type: maps a transformation name to its entry in
    NUMERIC_TRANSFORMATIONS. Unknown names are reported through
    raise_error with the list of known names.
    """
    known_names = ", ".join(NUMERIC_TRANSFORMATIONS.keys())
    raise_error(arg not in NUMERIC_TRANSFORMATIONS,
                "Unknown transformation: %s (known: %s)" % (arg, known_names))
    return NUMERIC_TRANSFORMATIONS[arg]



""">>>>>>>>>> Constructed Object Types >>>>>>>>>>"""


def stream_definition(arg):
    """
    argparse type: expands the placeholders of ``arg`` with
    format_stream_definitions and constructs an object from the
    StreamDefinition register via parser.construct.
    """
    # NOTE(review): ALL_TASKS and TMP_DIR are not defined in this file
    # section; presumably provided by the star import — confirm.
    item_cache = parser_tools.ItemCache()
    register = parser_tools.main_register.get_register(StreamDefinition)
    formatted = format_stream_definitions(arg, ALL_TASKS, TMP_DIR)
    return parser.construct(item_cache, register, formatted)


def learner(arg):
    """
    argparse type: constructs an object from the Learner register via
    parser.construct, using ``arg`` as the description.
    """
    item_cache = parser_tools.ItemCache()
    register = parser_tools.main_register.get_register(Learner)
    return parser.construct(item_cache, register, arg)

def keras_network(arg):
    """
    argparse type: constructs an object from the KerasNetwork register via
    parser.construct, using ``arg`` as the description.
    """
    item_cache = parser_tools.ItemCache()
    register = parser_tools.main_register.get_register(KerasNetwork)
    return parser.construct(item_cache, register, arg)


def callback(arg):
    """
    argparse type: constructs an object from the BaseKerasCallback register
    via parser.construct. It also records the object in
    ``callback.constructed_callbacks``.
    """
    item_cache = parser_tools.ItemCache()
    register = parser_tools.main_register.get_register(BaseKerasCallback)
    constructed = parser.construct(item_cache, register, arg)
    callback.constructed_callbacks.append(constructed)
    return constructed


# All callbacks built through `callback`, in construction order.
callback.constructed_callbacks = []

""">>>>>>>>>> Nested Types >>>>>>>>>>"""


def increase_scrambling(arg):
    """
    argparse type for a repeating triple of values:
    (callback, positive int, positive int).

    The function attribute ``phase`` (0, 1 or 2) records which element of
    the triple the next value is. It advances before the value is parsed.
    Presumably used with an option taking three values per occurrence —
    confirm with the argparse setup.
    """
    # phase -> (parser for the current value, phase for the next value)
    transitions = {
        0: (callback, 1),
        1: (positive_int, 2),
        2: (positive_int, 0),
    }
    current = increase_scrambling.phase
    if current in transitions:
        parse, following = transitions[current]
        increase_scrambling.phase = following
        return parse(arg)
    assert False, "Internal error."


increase_scrambling.phase = 0


if __name__ == "__main__":
    # This module only defines argparse types and helpers; executing it as a
    # script prints a message and exits with status 1.
    print("Thou shall not call me directly.")
    sys.exit(1)
