from .. import parser
from .. import parser_tools as parset
from .. import AbstractBaseClass

from ..parser_tools import main_register, ArgumentException
from ..variable import Variable

import abc
import os


class InvalidMethodCallException(Exception):
    """
    Raised when a method is called on an object that does not support it or
    is not in a state that allows the call (e.g. a learner being initialized
    or finalized a second time).
    """


class TrainingOutcome(object):
    """
    Enum-like record describing the state in which a training run ended.

    Each constructed instance registers itself twice: as a class attribute
    named after it (e.g. ``TrainingOutcome.Finished``) and in ``name2obj``.
    """
    # Lookup table of every outcome created so far, keyed by its name.
    name2obj = {}

    def __init__(self, name, started, stopped, failure):
        """
        :param name: identifier under which the outcome is registered
        :param started: whether the training was started
        :param stopped: whether the training has stopped
        :param failure: whether the training failed (requires stopped)
        """
        self.name, self.started = name, started
        self.stopped, self.failure = stopped, failure
        # A failed training must also be a stopped one.
        assert not self.failure or self.stopped

        self._add_to_enum()

    def _add_to_enum(self):
        """Publish this outcome on the class and in the name registry."""
        registry_cls = TrainingOutcome
        setattr(registry_cls, self.name, self)
        registry_cls.name2obj[self.name] = self

    @staticmethod
    def by_name(name):
        """
        Return the outcome registered under ``name``.
        :raises KeyError: if no outcome with this name exists
        """
        registry = TrainingOutcome.name2obj
        return registry[name]


# Register the known training outcomes (name, started, stopped, failure) in
# this order; each constructor call adds the outcome to the TrainingOutcome
# registry.
for _outcome_spec in (
        ("NotStarted", False, False, False),
        ("Aborted", True, True, False),
        ("Finished", True, True, False),
        ("Failed", True, True, True),
):
    TrainingOutcome(*_outcome_spec)
del _outcome_spec


class LearnerFormat(object):
    """
    Enum-like description of a file format in which a learner can be stored
    or from which it can be loaded.

    Each constructed instance registers itself as a class attribute named
    after it (e.g. ``LearnerFormat.pickle``), in ``name2obj`` (by name) and
    in ``suffix2obj`` (once per file suffix).
    """
    # Lookup table of all formats keyed by their name.
    name2obj = {}
    # Lookup table of all formats keyed by each of their file suffixes.
    suffix2obj = {}

    def __init__(self, name, suffix, description):
        """
        :param name: identifier of the format (also used as class attribute)
        :param suffix: single file suffix (without the leading dot) or a
                       list/tuple of suffixes belonging to this format
        :param description: human readable description of the format
        """
        self.name = name
        self.suffix = (list(suffix) if isinstance(suffix, (list, tuple))
                       else [suffix])
        self.description = description
        self._add_to_enum()

    def _add_to_enum(self):
        """Publish this format on the class and in both lookup tables."""
        setattr(LearnerFormat, self.name, self)
        LearnerFormat.name2obj[self.name] = self
        for suffix in self.suffix:
            # Every suffix must identify exactly one format.
            assert suffix not in LearnerFormat.suffix2obj, (
                "Suffix already registered for a LearnerFormat: %s" % suffix)
            LearnerFormat.suffix2obj[suffix] = self

    def __str__(self):
        # join() also handles an empty suffix list correctly ("name()").
        return "%s(%s)" % (self.name, ", ".join(self.suffix))

    @staticmethod
    def _get(name, mapping):
        """
        Look ``name`` up in ``mapping``.
        :raises KeyError: if the key is unknown
        """
        if name not in mapping:
            raise KeyError("Unknown key for LearnerFormat: " + str(name))
        return mapping[name]

    @staticmethod
    def by_suffix(suffix):
        """Return the format registered for the file suffix ``suffix``."""
        return LearnerFormat._get(suffix, LearnerFormat.suffix2obj)

    @staticmethod
    def by_name(name):
        """Return the format registered under the name ``name``."""
        return LearnerFormat._get(name, LearnerFormat.name2obj)

    @staticmethod
    def by_any(key):
        """
        Resolve ``key`` to a format: an existing LearnerFormat is returned
        as is, otherwise ``key`` is tried as a suffix first and then as a name.
        :raises KeyError: if ``key`` identifies no format
        """
        if isinstance(key, LearnerFormat):
            # Allows using by_any as an idempotent converter.
            return key
        for getter in [LearnerFormat.by_suffix, LearnerFormat.by_name]:
            try:
                return getter(key)
            except KeyError:
                pass
        # str() so that non-string keys produce a KeyError, not a TypeError.
        raise KeyError("Unknown key to identify LearnerFormat: " + str(key))


# Register the known learner formats (name, suffix, description) in this
# order; each constructor call adds the format to the LearnerFormat lookup
# tables.
for _format_spec in (
        ("pickle", "pkl", "Pickled network."),
        ("flag", "flag", "Simple flag file. With or without content."),
        ("protobuf", "pb", "Protobuf Format"),
        ("hdf5", "h5", "hdf5 format (e.g. used by Keras)"),
        ("coefficients", "coef",
         "saves the coefficients of the learner in a json format"),
        ("running", "running",
         "shows that a training is currently running (check content to"
         " verify that deleting was not forgotten)"),
):
    LearnerFormat(*_format_spec)
del _format_spec


class Learner(AbstractBaseClass):
    """
    Base class for all learners.
    Do not forget to register your network subclass in this packages 'register'
    dictionary via 'append_register' of the main package.

    Life cycle implemented by this base class:
      1. ``__init__`` stores paths/formats and validates the store formats.
      2. ``initialize`` runs ``_initialize_general`` and then either loads a
         stored model (``load`` -> ``_load``) or creates a fresh one
         (``_initialize_model``).
      3. ``train``/``evaluate``/``analyse`` use the learner.
      4. ``finalize`` runs ``_finalize`` and stores the learner if a store
         path was given at construction.
    """

    arguments = parset.ClassArguments(
        "Learner", None,
        ('load', True, None, str, "File to load learner from"),
        ('store', True, None, str, "Path (w/o suffix) where to store learner"),
        ('learner_formats', True, None, LearnerFormat.by_any,
         "Single or list of formats in which to save learner"),
        ('out', True, '.', str, "Path to directory for network outputs"),
        ('variables', True, None, main_register.get_register(Variable)),
        ('id', True, None, str),
    )

    def __init__(self, load=None, store=None, learner_formats=None, out=".",
                 variables=None, id=None):
        """
        :param load: file to load the learner from (the suffix may be omitted;
                     then the suffixes of the load formats are tried)
        :param store: path (without suffix) where to store the learner
        :param learner_formats: single LearnerFormat or list of them in which
                                the learner shall be stored
        :param out: directory for the outputs of the learner
        :param variables: dict of named variables ({name=VARIABLE,...})
        :param id: identifier of this learner object
        :raises ArgumentException: if variables is not a dict
        :raises ValueError: if a requested store format is not supported
        """
        variables = {} if variables is None else variables
        if not isinstance(variables, dict):
            raise ArgumentException(
                "The provided variables have to be a map. Please define them "
                "as {name=VARIABLE,...}.")

        self.path_load = os.path.abspath(load) if load is not None else load
        self.path_store = os.path.abspath(store) if store is not None else store
        self.path_out = os.path.abspath(out) if out is not None else out
        self.learner_formats = (learner_formats
                                if isinstance(learner_formats, list)
                                else [learner_formats])
        self._check_store_formats()

        self.msgs = None
        self.variables = variables  # already defaulted to {} above
        self.id = id

        self.initialized = False
        self.finalized = False

    @abc.abstractmethod
    def get_default_format(self):
        """
        Return default/preferred format to store learner of this class.
        :return: LearnerFormat
        """
        pass

    @abc.abstractmethod
    def _get_store_formats(self):
        """
        Return iterable of all formats in which networks of this class can be
        stored.
        :return: iterable (with "in" operator) of LearnerFormat objects
        """
        pass

    @abc.abstractmethod
    def _get_load_formats(self):
        """
        Return iterable of all formats from which networks of this class can be
        loaded.
        :return: iterable (with "in" operator) of LearnerFormat objects
        """
        pass

    def get_preferred_state_formats(self):
        """
        Most learners work on sampled PDDL states which can be represented in
        different formats. This method returns the supported StateFormats of
        the network in order of preference.
        If your network does not work on those states, let the method raise an
        exception (like it currently does)
        :return: list of supported StateFormat objects (see SamplingBridges)
        """
        raise InvalidMethodCallException("The learner does not support "
                                         "state formats.")

    def _find_suffixed_load_file(self, path):
        """
        Look for an existing file "<path>.<suffix>" for every suffix of the
        load formats, in the order given by _get_load_formats.
        :param path: path without suffix
        :return: tuple (file path, LearnerFormat) of the first existing file,
                 or (None, None) if none exists
        """
        for learner_format in self._get_load_formats():
            for suffix in learner_format.suffix:
                candidate = "%s.%s" % (path, suffix)
                if os.path.exists(candidate):
                    return candidate, learner_format
        return None, None

    def initialize(self, msgs, *args, **kwargs):
        """
        Build network object, load model (if requested) and prepare it.
        :param msgs: Message object for communication between objects (if given)
        :param args: forwarded to _initialize_general and _initialize_model
        :param kwargs: forwarded to _initialize_general and to load or
                       _initialize_model; the key 'skip_loading' (default
                       False) suppresses loading from the load path
        :raises InvalidMethodCallException: if called more than once
        """
        skip_loading = kwargs.get("skip_loading", False)
        if self.initialized:
            raise InvalidMethodCallException(
                "Multiple initializations of network.")
        self.msgs = msgs

        self._initialize_general(*args, **kwargs)
        assert (self.path_load is None or skip_loading or
                os.path.exists(self.path_load) or
                self._find_suffixed_load_file(self.path_load)[0] is not None
                ), self.path_load
        if self.path_load is not None and not skip_loading:
            self.load(**kwargs)
        else:
            self._initialize_model(*args, **kwargs)

        self.initialized = True

    @abc.abstractmethod
    def reinitialize(self, *args, **kwargs):
        """
        Reinitialize the network (what to do may depend on the network,
        e.g. if loading a network from a file, there might be nothing to do,
        on other occasions, the weights might be new randomly initialized).
        """
        pass

    @abc.abstractmethod
    def _initialize_general(self, *args, **kwargs):
        """Initialization code except for the model initialization"""
        pass

    @abc.abstractmethod
    def _initialize_model(self, *args, **kwargs):
        """Initialization of the model (if it is not loaded!)"""
        pass

    def finalize(self, *args, **kwargs):
        """
        Finalize the learner via _finalize and store it if a store path was
        given at construction time.
        :raises InvalidMethodCallException: if called more than once
        """
        if self.finalized:
            raise InvalidMethodCallException("Multiple finalization calls of "
                                             "network.")
        self._finalize(*args, **kwargs)
        if self.path_store is not None:
            self.store()
        self.finalized = True

    @abc.abstractmethod
    def _finalize(self, *args, **kwargs):
        pass

    def load(self, path=None, learner_format=None, **kwargs):
        """
        Load network from a file
        :param path: If given, file to load network from, else the path given
                     at construction time is used. If this is also not given,
                     a ValueError is raised. If the file does not exist and no
                     learner_format is given, "<path>.<suffix>" is tried for
                     the suffixes of all load formats.
        :param learner_format: Network format of the file to load. If not given,
                               it is inferred from the files suffix if possible.
        :param kwargs: forwarded to _load
        :raises ValueError: if no path is known, no file is found or the
                            format is not supported for loading
        :raises KeyError: if the format has to be inferred from an unknown
                          file suffix
        """
        path = self.path_load if path is None else path
        if path is None:
            raise ValueError("No path defined for loading of a network")
        if not os.path.exists(path):
            found_path, found_format = None, None
            if learner_format is None:
                found_path, found_format = self._find_suffixed_load_file(path)
            if found_path is None:
                raise ValueError("File does not exist to load network from: "
                                 + str(path))
            path, learner_format = found_path, found_format
        if learner_format is None:
            suffix = os.path.splitext(path)[1][1:]
            learner_format = LearnerFormat.by_suffix(suffix)
        if learner_format not in self._get_load_formats():
            raise ValueError("The learner file to load is not of a format"
                             " supported for loading: " + str(learner_format))
        self._load(path, learner_format, **kwargs)

    @abc.abstractmethod
    def _load(self, path, learner_format, *args, **kwargs):
        pass

    def get_store_path(self, path=None):
        """
        Return the path (without suffix) under which to store the learner:
        ``path`` if given, else the store path from construction, else
        "<out>/model" (with "." if no output directory is known).
        """
        path = self.path_store if path is None else path
        path = os.path.join("." if self.path_out is None else self.path_out,
                            "model") if path is None else path
        return path

    def store(self, path=None, learner_formats=None, allow_uninitialized=False,
              path_suffix=None):
        """
        Stores the network in the specified formats.
        :param path: Path without suffix where to store the network. If None,
                     the path given at construction is used. If even this is
                     None, "<out>/model" is used with the output directory
                     defined at construction (or "./model" if that is None).
        :param learner_formats:
            Single LearnerFormat or iterable of LearnerFormat in which to
            store the network. If None is given, then the formats given at
            construction are used. If None were given there, [None] is passed
            on to _store (NOTE(review): presumably _store maps None to the
            default format - confirm in subclasses).
        :param allow_uninitialized: if True, store even if no model is set
        :param path_suffix: suffix that is added after the path prior to the
                            file extension
        :raises ValueError: if a format is not supported for storing or the
                            model is uninitialized and not allowed to be
        """
        path = self.get_store_path(path)
        if path_suffix is not None:
            path += path_suffix
        if learner_formats is None:
            learner_formats = self.learner_formats
        elif isinstance(learner_formats, LearnerFormat):
            learner_formats = [learner_formats]
        self._check_store_formats(learner_formats)
        # _model is set by subclasses; a learner that never assigned it is
        # treated like an uninitialized one instead of raising AttributeError.
        if getattr(self, "_model", None) is None and not allow_uninitialized:
            raise ValueError("Uninitialized model cannot be saved!")
        self._store(path, learner_formats)

    @abc.abstractmethod
    def _store(self, path, learner_formats):
        pass

    def _check_store_formats(self, learner_formats=None):
        """
        Verify that every format (None is always accepted) can be stored by
        this learner.
        :param learner_formats: formats to check, defaults to the formats
                                given at construction
        :raises ValueError: for the first unsupported format
        """
        learner_formats = (self.learner_formats if learner_formats is None
                           else learner_formats)
        sf = self._get_store_formats()
        for f in learner_formats:
            if not (f is None or f in sf):
                raise ValueError("The learner does not support storing "
                                 "a desired format: " + str(f))

    def _convert_data(self, data):
        """
        Hook for converting the given data into the format needed by this
        learner. The default implementation does nothing and returns None;
        subclasses needing a conversion override it.
        :param data: data to convert
        :return: converted data (None in this base implementation)
        """
        pass

    @abc.abstractmethod
    def train(self, data):
        """
        Train the network on the given data
        :param data:
        :return: dict with interesting values. Use 'failure': bool to indicate
                 if the training has failed.
        """
        pass

    def train_reinforcement(self, data_generator):
        """
        Train the learner via reinforcement learning (or not);
        the input data is given by a generating object.
        :param data_generator: Todo: what are the requirements?
        :raises NotImplementedError: in this base implementation
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def evaluate(self, data):
        """
        Evaluate the learner on the given data.
        :param data: data to evaluate on
        :return: defined by the subclass
        """
        pass

    def analyse(self, directory=None, prefix=""):
        """
        Analyse the network performance.
        This functionality is optional and not every network supports it.
        :param directory: Path to directory for storing the analysis results.
                          If None is given, the output directory given at
                          construction time is used. If this is also None,
                          the current working directory is used.
        :param prefix: A prefix which shall be added in front of every file name
                       which the analysis is producing. If None is given, no
                       prefix is used.
        """
        directory = self.path_out if directory is None else directory
        directory = "." if directory is None else directory
        prefix = "" if prefix is None else prefix
        self._analyse(directory, prefix)

    def _analyse(self, directory, prefix):
        """Hook for the actual analysis; does nothing by default."""
        pass

    @staticmethod
    def parse(tree, item_cache):
        """
        Resolve a parse tree to a previously defined Learner object.
        :raises ArgumentException: if no such object can be looked up
        """
        obj = parser.try_lookup_obj(tree, item_cache, Learner, None)
        if obj is not None:
            return obj
        else:
            raise ArgumentException("The definition of the base network can "
                                    "only be used for look up of any previously"
                                    " defined schema via 'Learner(id=ID)'")


# Register the Learner base class with the main register under the key
# "learner".
main_register.append_register(Learner, "learner")
# Module-level handle to the register belonging to Learner.
# NOTE(review): presumably subclasses add themselves to it - confirm.
lregister = main_register.get_register(Learner)
