from __future__ import print_function

import argparse
import enum
import os
import re
import sys


# Parsed values recorded by named_type(), keyed by the given name.
NAMED_ARGS = dict()
# Every path accepted by istask(), in the order it was parsed.
ALL_TASKS = list()
# Directory registered by tmpdir(); None until it has been set.
TMP_DIR = None


def raise_error(condition, msg):
    """Raise ``argparse.ArgumentTypeError(msg)`` if *condition* is truthy.

    Used by every validator in this module so that argparse reports the
    failure with *msg* as the error text.
    """
    if not condition:
        return
    raise argparse.ArgumentTypeError(msg)


def combine(*funcs):
    """Compose single-argument callables from left to right.

    ``combine(f, g)(x)`` evaluates ``g(f(x))``; with no callables the result
    is the identity function.
    """
    pipeline = funcs

    def _combine(arg):
        value = arg
        for step in pipeline:
            value = step(value)
        return value
    return _combine


""">>>>>>>>>> Simple Types >>>>>>>>>>"""


def float_interval(min_value=None, max_value=None):
    """Return an argparse type parsing a float inside [min_value, max_value].

    Both bounds are inclusive; either may be None to leave that side open.

    The returned parser raises:
        argparse.ArgumentTypeError: if the value is out of range, or is NaN
            while at least one bound is set.
        ValueError: (from ``float()``) if the text is not a number at all.
    """
    def _float_interval(arg):
        arg = float(arg)
        # NaN compares False against every bound, so it would otherwise slip
        # through both checks below. NaN is the only float unequal to itself.
        bounded = min_value is not None or max_value is not None
        raise_error(bounded and arg != arg, "%s is not a number" % arg)
        raise_error(min_value is not None and arg < min_value,
                    "Minimum value is %s" % str(min_value))
        raise_error(max_value is not None and arg > max_value,
                    "Maximum value is %s" % str(max_value))
        return arg
    return _float_interval


# Non-negative float: lower bound 0 (inclusive), no upper bound.
float_zero_positive = float_interval(0)


def int_interval(min_value=None, max_value=None):
    """Return an argparse type parsing an integer inside [min_value, max_value].

    Both bounds are inclusive; either may be None to leave that side open.
    Whole numbers written in float notation ("3.0", "1e3") are accepted and
    converted to int. Integers of any size are checked exactly: there is no
    round-trip through float, which loses precision above 2**53.

    The returned parser raises:
        argparse.ArgumentTypeError: if the value is not a whole number or is
            out of range.
        ValueError: if the text is not a number at all.
    """
    import numbers

    def _int_interval(arg):
        try:
            value = int(arg)
        except ValueError:
            # int() rejects text such as "3.0" or "2.5"; go through float() to
            # tell whole numbers from fractional ones. float() itself raises
            # ValueError for non-numeric text.
            number = float(arg)
            raise_error(not number.is_integer(),
                        "%s is not an integer" % arg)
            value = int(number)
        else:
            # int() silently truncates numeric inputs (int(2.5) == 2), so
            # compare exactly against the original number.
            raise_error(isinstance(arg, numbers.Number) and value != arg,
                        "%s is not an integer" % arg)
        raise_error(min_value is not None and value < min_value,
                    "Minimum value is %s" % str(min_value))
        raise_error(max_value is not None and value > max_value,
                    "Maximum value is %s" % str(max_value))
        return value
    return _int_interval


# Integer >= 0.
int_zero_positive = int_interval(0)
# Integer >= 1.
int_positive = int_interval(1)


def isfile(arg):
    """Return *arg* unchanged if ``os.path.isfile`` accepts it.

    Raises:
        argparse.ArgumentTypeError: "<arg> is not a file." if something else
            (e.g. a directory) exists at that path, otherwise
            "<arg> does not exist.".
    """
    if not os.path.isfile(arg):
        # Tell "missing" apart from "wrong kind" so an existing directory is
        # not reported as nonexistent.
        raise_error(os.path.exists(arg), "%s is not a file." % arg)
        raise_error(True, "%s does not exist." % arg)
    return arg


# Validate with isfile(), then return the path made absolute.
absfile = combine(isfile, os.path.abspath)


def isdir(arg):
    """Return *arg* unchanged if it names an existing directory.

    Raises argparse.ArgumentTypeError (via raise_error) otherwise.
    """
    is_directory = os.path.isdir(arg)
    raise_error(not is_directory, "%s is not a directory" % arg)
    return arg


# Validate with isdir(), then return the path made absolute.
absdir = combine(isdir, os.path.abspath)


def suffix(sfx):
    """Return a validator that passes its argument through only if it ends
    with *sfx*; otherwise it raises argparse.ArgumentTypeError."""
    def _suffix(arg):
        lacks_suffix = not arg.endswith(sfx)
        raise_error(lacks_suffix,
                    "Does not end with the required suffix (%s): %s"
                    % (sfx, arg))
        return arg
    return _suffix


# Existing file whose name ends in ".sh".
isbash = combine(isfile, suffix(".sh"))


def istask(arg):
    """Validate *arg* as an existing file, remember it in ALL_TASKS and
    return it unchanged."""
    task_path = isfile(arg)
    ALL_TASKS.append(task_path)
    return task_path


def allow_environment_variable(func):
    """Wrap *func* so a "$NAME" argument is replaced by the value of the
    environment variable NAME before being passed to *func*.

    Arguments not starting with "$" are passed to *func* unchanged.

    The returned parser raises:
        argparse.ArgumentTypeError: if the referenced variable is not set.
    """
    def _allow_environment_variable(arg):
        if arg.startswith("$"):
            # Keep the variable name separate from its value so the error
            # message can name the missing variable (it used to say "None").
            name = arg[1:]
            value = os.environ.get(name)
            raise_error(value is None, "Variable is not defined: %s" % name)
            arg = value
        return func(arg)
    return _allow_environment_variable


def tmpdir(arg):
    """Validate *arg* as a directory and store it in the global TMP_DIR.

    Only one call may succeed: if TMP_DIR is already set, an
    argparse.ArgumentTypeError is raised before *arg* is examined.
    """
    global TMP_DIR
    already_set = TMP_DIR is not None
    raise_error(already_set, "TMP_DIR already set.")
    directory = isdir(arg)
    TMP_DIR = directory
    return directory


def units(regex, unit_to_base):
    """Return a parser for "<number><unit>" strings, scaled to the base unit.

    *regex* must capture the number as group 1 and the optional unit as
    group 2. *unit_to_base* maps each unit to its multiplier. The result is a
    float; if the unit group did not participate in the match, the number is
    returned unscaled.
    """
    def f(arg):
        match = regex.match(arg)
        raise_error(match is None,
                    "Argument does not match the unit syntax: %s" % arg)
        number, unit_name = match.groups()
        number = float(number)
        unknown = unit_name is not None and unit_name not in unit_to_base
        raise_error(unknown, "Unknown unit: %s" % unit_name)
        if unit_name is None:
            return number
        return number * unit_to_base[unit_name]
    return f


# Duration suffixes mapped to their length in seconds ("" = plain seconds).
TIME_UNITS = {"s": 1, "": 1, "m": 60, "h": 3600, "d": 86400}
REGEX_TIME = re.compile(
    r"^(\d+(?:\.\d+)?)(%s)?$" % "|".join(TIME_UNITS))
# Parses e.g. "90", "1.5h" or "2d" into a number of seconds (float).
time = units(REGEX_TIME, TIME_UNITS)

# Lower-case memory suffixes mapped to their size in bytes (powers of 1024);
# a number without a suffix is returned unscaled.
MEMORY_UNITS = {"kb": 1024, "mb": 1024 ** 2, "gb": 1024 ** 3, "tb": 1024 ** 4}
REGEX_MEMORY = re.compile(
    r"^(\d+(?:\.\d+)?)(%s)?$" % "|".join(MEMORY_UNITS))
memory = units(REGEX_MEMORY, MEMORY_UNITS)


# Optional arithmetic modifier followed by an unsigned integer, e.g. "5",
# "+5" or "*2". Anchored at both ends so trailing garbage such as "5abc" is
# rejected instead of being silently ignored.
REGEX_INT_MODIFICATION = re.compile(r"^([+\-*/])?(\d+)$")


def int_modification(arg):
    """Parse "[+-*/]<int>" into a ``(modifier, value)`` tuple.

    A missing modifier defaults to "+", e.g. "*3" -> ("*", 3) and
    "7" -> ("+", 7).

    Raises:
        argparse.ArgumentTypeError: if *arg* does not match the pattern.
    """
    m = REGEX_INT_MODIFICATION.match(arg)
    raise_error(
        m is None,
        "Argument does not match the pattern (%s): %s" % (
            REGEX_INT_MODIFICATION.pattern, arg))
    modifier, value = m.groups()
    return ("+" if modifier is None else modifier), int(value)


# Optional arithmetic modifier followed by an unsigned decimal number, e.g.
# "2", "+0.5" or "*1.5". The dot is escaped so only a literal decimal point is
# accepted, and the pattern is anchored so trailing garbage is rejected.
# Group 3 (the fractional part) is kept so the group layout is unchanged.
REGEX_FLOAT_MODIFICATION = re.compile(r"^([+\-*/])?(\d+(\.\d+)?)$")


def float_modification(arg):
    """Parse "[+-*/]<number>" into a ``(modifier, float_value)`` tuple.

    A missing modifier defaults to "+", e.g. "*1.5" -> ("*", 1.5).

    Raises:
        argparse.ArgumentTypeError: if *arg* does not match the pattern.
    """
    m = REGEX_FLOAT_MODIFICATION.match(arg)
    raise_error(
        m is None,
        "Argument does not match the pattern (%s): %s" % (
            REGEX_FLOAT_MODIFICATION.pattern, arg))
    # Only the modifier and the full number are needed; group 3 is the
    # nested fractional part.
    modifier, value = m.groups()[0:2]
    return ("+" if modifier is None else modifier), float(value)


def has_seeds(arg):
    """Give every "{seed}" placeholder in *arg* its own numbered name.

    Occurrences are renamed left to right to "{seed<k>}", where k continues
    from the running total in ``has_seeds.count``. Numbers therefore never
    repeat across calls.
    """
    placeholder = "{seed}"
    pieces = arg.split(placeholder)
    found = len(pieces) - 1
    first_index = has_seeds.count
    rebuilt = [pieces[0]]
    for offset, piece in enumerate(pieces[1:]):
        rebuilt.append("{seed%i}" % (first_index + offset))
        rebuilt.append(piece)
    arg = "".join(rebuilt)
    raise_error(arg.count(placeholder) != 0, "Internal error")
    has_seeds.count += found
    return arg


# Running total of "{seed}" placeholders numbered so far.
has_seeds.count = 0


""">>>>>>>>>> Nested Types >>>>>>>>>>"""


def named_type(func, name):
    """Wrap *func* so each parsed result is also stored as NAMED_ARGS[name].

    Later calls overwrite the stored value.
    """
    def _named_type(arg):
        parsed = func(arg)
        NAMED_ARGS[name] = parsed
        return parsed
    return _named_type


def restricted_type(func, check_func):
    """Wrap *func* so each parsed result is validated by
    ``check_func(result, NAMED_ARGS)`` before being returned."""
    def _restricted_type(arg):
        parsed = func(arg)
        check_func(parsed, NAMED_ARGS)
        return parsed
    return _restricted_type


def split_type(char, func1=str, func2=str):
    """Return a parser for "<first><char><second>" strings.

    The argument is split at the first occurrence of *char*. The part before
    is parsed with *func1* and the part after with *func2*, and the two
    results are returned as a pair. *char* may be longer than one character.

    The returned parser raises:
        argparse.ArgumentTypeError: if *char* does not occur in the argument.
    """
    def _split_type(arg):
        idx = arg.find(char)
        raise_error(idx == -1, "Cannot find splitting char %s in %s" %
                    (char, arg))
        # Skip the whole separator, not a single character, so a
        # multi-character separator does not leak into the second part.
        return func1(arg[:idx]), func2(arg[idx + len(char):])
    return _split_type


class ListModes(enum.Enum):
    """What a list_type parser does after its last function has been used."""
    Loop = "loop"  # cycles through the functions in the list
    Last = "last"  # repeats at the end only the last function


def list_type(*funcs, **kwargs):
    """Return a parser that applies *funcs* to successive arguments in turn.

    The n-th call parses its argument with ``funcs[n]``. Once the last
    function has been used, ``mode`` decides what happens next:
    ``ListModes.Loop`` (the default) starts again at ``funcs[0]``, and
    ``ListModes.Last`` keeps using ``funcs[-1]``. ``mode`` may also be given
    as the enum's value ("loop" / "last").

    The current position is exposed as the parser's ``phase`` attribute. It
    only advances when the function returns normally.

    Raises:
        ValueError: if no functions are given, or *mode* is neither a
            ListModes member nor one of its values.
    """
    # Validate at construction time; previously misuse only surfaced later
    # as an opaque "Internal error" while parsing.
    mode = ListModes(kwargs.get("mode", ListModes.Loop))
    if not funcs:
        raise ValueError("list_type: at least one function is required.")

    def _list_type(arg):
        arg = funcs[_list_type.phase](arg)
        if _list_type.phase + 1 < len(funcs):
            _list_type.phase += 1
        elif mode == ListModes.Loop:
            _list_type.phase = 0
        # ListModes.Last: stay on the final function.
        return arg
    _list_type.phase = 0
    return _list_type


def choice_type(choices):
    """Return a validator accepting only arguments contained in *choices*.

    The returned parser raises:
        argparse.ArgumentTypeError: if the argument is not in *choices*.
    """
    def _choice_type(arg):
        if arg not in choices:
            # Build the message only on failure, and stringify each choice:
            # the old eager ", ".join(choices) raised TypeError on every call,
            # even for valid arguments, whenever a choice was not a string.
            raise_error(
                True,
                "Argument (%s) is not a valid choice (%s)" % (
                    arg, ", ".join(str(choice) for choice in choices)))
        return arg
    return _choice_type


def placeholder_type(func, placeholders, parse_placeholders=True):
    """Wrap *func* so special arguments can stand in for preset values.

    If the argument is a key of *placeholders*, the mapped value is used
    instead. It is still passed through *func* unless *parse_placeholders*
    is false, in which case it is returned as-is.
    """
    def _placeholder_type(arg):
        if arg not in placeholders:
            return func(arg)
        substitute = placeholders[arg]
        if parse_placeholders:
            return func(substitute)
        return substitute
    return _placeholder_type


def default(func, default_value):
    """Wrap *func* so the literal argument "default" is replaced by
    *default_value*, which is then parsed by *func* like any other input."""
    substitutions = {"default": default_value}
    return placeholder_type(func, substitutions, parse_placeholders=True)


"""------------------------- argparse Restrictions --------------------------"""


def check_min_max_restriction(min_name, max_name):
    """Return a check (for restricted_type) on two named arguments.

    Once both NAMED_ARGS entries *min_name* and *max_name* are present and
    not None, it raises argparse.ArgumentTypeError unless min <= max. The
    checked argument itself is ignored.
    """
    def _check_min_max_restriction(_, named_args):
        low = named_args.get(min_name)
        high = named_args.get(max_name)
        both_known = low is not None and high is not None
        raise_error(
            both_known and not (low <= high),
            "Minimum value has to be smaller or equal to max value"
        )
    return _check_min_max_restriction


def check_buffer(func_init, func_update, func_check):
    """Return a check (for restricted_type) that accumulates state.

    The state starts as ``func_init()``. Each checked argument replaces it
    with ``func_update(state, arg, named_args)``, after which
    ``func_check(state)`` validates it (presumably by raising on failure).
    The current state is exposed as the ``cache`` attribute.
    """
    def _check_buffer(arg, named_args):
        updated = func_update(_check_buffer.cache, arg, named_args)
        _check_buffer.cache = updated
        func_check(_check_buffer.cache)
    _check_buffer.cache = func_init()
    return _check_buffer


# Every wrapper created by arg_counts(), in creation order; inspected by
# check_all_arg_counts_arguments().
ARG_COUNTS_FUNCTIONS = []


def arg_counts(name, func, intervals=None, values=None, required=False):
    """Wrap *func* so the number of times it is called can be validated later.

    The wrapper forwards all arguments to *func*, increments its ``count``
    attribute after each call that returns normally, and registers itself in
    ARG_COUNTS_FUNCTIONS. check_all_arg_counts_arguments() later judges the
    count:
    - a zero count is an error only when *required* is true;
    - a non-zero count is valid if neither *intervals* nor *values* was
      given, if it lies in any inclusive ``(low, high)`` pair of
      *intervals*, or if it equals one of *values*.

    Args:
        name: label used in error messages.
        func: callable to wrap (e.g. an argparse ``type``).
        intervals: optional iterable of inclusive ``(low, high)`` pairs.
        values: optional iterable of accepted exact counts.
        required: report an error if the wrapper was never called.

    Raises:
        ValueError: if an interval is not a pair, has low > high, or if
            *values* is not iterable.
    """
    if intervals is not None:
        # Materialize once: a generator would otherwise be exhausted by the
        # validation below (or by the first count check) and then look empty.
        intervals = tuple(intervals)
        if not all(len(x) == 2 for x in intervals):
            raise ValueError("arg_counts: all intervals have to be of size 2 "
                             "(e.g. [(1, 4), (5, 10)]).")
        if not all(x[0] <= x[1] for x in intervals):
            raise ValueError("arg_counts: all intervals have to be of the "
                             "form (x, y) with x <= y.")
    if values is not None:
        try:
            # Same reason as above: store a reusable sequence, not a one-shot
            # iterator.
            values = tuple(values)
        except TypeError:
            raise ValueError("arg_counts: values has to be an iterable.")

    def _arg_counts(*args, **kwargs):
        obj = func(*args, **kwargs)
        _arg_counts.count += 1
        return obj

    _arg_counts.count = 0
    _arg_counts.intervals = intervals
    _arg_counts.values = values
    _arg_counts.name = name
    _arg_counts.required = required
    ARG_COUNTS_FUNCTIONS.append(_arg_counts)
    return _arg_counts


def check_all_arg_counts_arguments():
    """Validate the call count of every arg_counts() wrapper.

    Raises:
        ValueError: listing, one per line, each offending wrapper as
            "<name>: Argument is required" or
            "<name>: Invalid argument count <n>".
    """
    errors = []
    for acf in ARG_COUNTS_FUNCTIONS:
        count = acf.count
        if count == 0:
            if acf.required:
                errors.append("%s: Argument is required" % (acf.name,))
            continue
        if acf.intervals is None and acf.values is None:
            continue  # no constraint on how often it may be given
        in_interval = any(low <= count <= high
                          for low, high in (acf.intervals or ()))
        in_values = count in (acf.values or ())
        if not (in_interval or in_values):
            errors.append("%s: Invalid argument count %s" % (acf.name, count))
    if errors:
        raise ValueError("\n".join(errors))


if __name__ == "__main__":
    # This module only provides argparse helpers and is meant to be imported.
    print("Thou shall not call me directly.")
    raise SystemExit(1)
