from src.training.learners import Learner, LearnerFormat, TrainingOutcome

from src.training import parser_tools as parset
from src.training import parser
from src.training import main_register

from src.training.bridges import StateFormat

from ...parser_tools import ArgumentException

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error

class BaseSKLearn(Learner):
    """Base learner wrapping an estimator with a scikit-learn style API.

    The wrapped estimator is stored in ``self._model`` and used through
    ``fit`` / ``predict``. This class does not create it:
    ``reinitialize`` delegates to ``_initialize_model``, which is not defined
    here (presumably provided by subclasses — confirm).
    """
    # NOTE(review): the arguments are registered under the name 'SVR' even
    # though this is the generic base class — confirm this is intended.
    arguments = parset.ClassArguments(
        'SVR', Learner.arguments,
        ("x_fields", False, None, str,
         "Fields to use as input (features of fields are concatenated)"),
        ("y_field", False, None, str,
         "Field name for label"),
        order=["x_fields", "y_field", "load", "store",
               "learner_formats", "out", "variables", "id"])

    def __init__(self, x_fields, y_field,
                 load=None, store=None, learner_formats=None, out=".",
                 variables=None, id=None):
        """
        :param x_fields: a single input field name, or an iterable of field
                         names whose features are concatenated column-wise
        :param y_field: name of the field holding the label
        The remaining parameters are forwarded unchanged to Learner.__init__.
        """
        Learner.__init__(self, load, store, learner_formats, out, variables, id)
        # A bare string is one field name; list("abc") would split it into
        # single characters.
        if isinstance(x_fields, str):
            self._x_fields = [x_fields]
        else:
            self._x_fields = list(x_fields)
        self._y_field = y_field
        self._model = None

        # Analysis data: the most recent evaluation result and all of them.
        self._evaluation = None
        self._evaluations = []

    def get_default_format(self):
        """Format used by _store when a requested format is None."""
        return LearnerFormat.coefficients

    def _get_store_formats(self):
        """Formats accepted by _store."""
        return set([LearnerFormat.coefficients,
                    LearnerFormat.flag])

    def _get_load_formats(self):
        """Formats accepted by _load."""
        return set([LearnerFormat.coefficients])

    def get_preferred_state_formats(self):
        return [StateFormat.All_A_01]

    def _load(self, path, format):
        """Load the model stored at ``path`` in ``format``.

        Loading is not actually implemented: for the coefficients format only
        a warning is printed and the model is left unchanged.
        :raises ValueError: if ``format`` is not the coefficients format
        """
        if format == LearnerFormat.coefficients:
            print("Warning: actually, we cannot load this...")
        else:
            raise ValueError("Linear Regression cannot be loaded from: " + str(format))

    def _store(self, path, learner_formats):
        """Store the learner at ``path`` once per requested format.

        The coefficients format is not implemented (a warning is printed).
        The flag format creates an empty file ``path + "." + suffix``.
        A ``None`` entry stands for the default format.
        """
        for learner_format in learner_formats:
            if learner_format is None:
                learner_format = self.get_default_format()

            path_format = path + "." + learner_format.suffix[0]
            if learner_format == LearnerFormat.coefficients:
                print("Warning: actually, we cannot save this...")
            elif learner_format == LearnerFormat.flag:
                # Only the existence of the (empty) file matters.
                with open(path_format, "w"):
                    pass

    def reinitialize(self, *args, **kwargs):
        """Load the model if a load path is set, else build a fresh one."""
        if self.path_load is not None:
            self.load(**kwargs)
        else:
            self._initialize_model(*args, **kwargs)

    def _initialize_general(self, *args, **kwargs):
        pass

    def _finalize(self):
        pass

    def train(self, dtrain, dvalid=None):
        """
        Convert the given data into feature/label arrays and fit the model.

        NOTE(review): the validation data (if given and different from the
        training data) is concatenated into the training set, because
        _convert_data merges every data argument. Confirm this is intended.
        :param dtrain: SampleBatchData or list of SampleBatchData for training
        :param dvalid: SampleBatchData or list of SampleBatchData for
                       validation (optional)
        :return: dict with the key "training_outcome"
        """
        # Avoid converting (and appending) the same data twice.
        if dvalid == dtrain:
            dvalid = None
        x_train, y_train = self._convert_data(dtrain, dvalid)

        self._model.fit(x_train, y_train)

        return {"training_outcome": TrainingOutcome.Finished}

    def evaluate(self, data):
        """Predict on ``data`` and record the result.

        :return: tuple ``(y_pred, y_test)``; it is also stored as the latest
                 evaluation and appended to the evaluation history
        """
        x_test, y_test = self._convert_data(data)
        y_pred = self._model.predict(x_test)

        result = (y_pred, y_test)

        self._evaluation = result
        self._evaluations.append(result)
        return result

    # ----------------------- DATA PARSING METHODS -----------------------
    def _convert_data(self, *datas):
        """
        Finalize the given SampleBatchData objects and flatten them into
        (x, y) arrays.

        For every non-empty batch, the features of all ``self._x_fields``
        are stacked and concatenated column-wise to form the rows of x.
        The ``self._y_field`` column forms y. All data arguments are merged
        into a single pair of arrays.
        :param datas: SampleBatchData, lists of them, or None (ignored)
        :return: tuple ``(x, y)`` of numpy arrays
        :raises ValueError: if the feature sizes of a field are inconsistent,
                            or if there is no non-empty batch at all
        """
        x_parts = []
        y_parts = []
        field_sizes = None

        for data in datas:
            if data is None:
                continue
            data = data if isinstance(data, list) else [data]
            for data_set in data:
                if data_set.is_finalized:
                    print("Warning: Data set previously finalized. Skipping now.")
                    continue
                data_set.finalize()

            for sbd in data:
                idx_y = sbd.fields.index(self._y_field)
                idxs_x = [sbd.fields.index(x) for x in self._x_fields]
                for entry_type in sbd.data.keys():
                    for batch in sbd.data[entry_type]:
                        if len(batch) == 0:
                            continue
                        # The first batch fixes the expected feature length of
                        # each input field; every later entry must match it.
                        if field_sizes is None:
                            field_sizes = [len(batch[0, x]) for x in idxs_x]
                        if any(len(entry) != field_sizes[no]
                               for no, x in enumerate(idxs_x)
                               for entry in batch[:, x]):
                            raise ValueError(
                                "Inconsistent feature sizes for input fields "
                                + str(self._x_fields))
                        x_parts.append(
                            np.concatenate([np.stack(batch[:, x]) for x in idxs_x],
                                           axis=1))
                        y_parts.append(batch[:, idx_y])

        if not x_parts:
            raise ValueError("No non-empty batches found to convert.")
        return np.concatenate(x_parts), np.concatenate(y_parts)

    # ----------------------- ANALYSE PREDICTIONS ------------------------
    def _analyse(self, directory, prefix):
        """Print MSE and MAE of the most recent evaluation.

        ``directory`` and ``prefix`` are currently unused.
        """
        if self._evaluation is None:
            print("Warning: no evaluation to analyse. Call evaluate() first.")
            return
        y_pred, y_test = self._evaluation
        # sklearn metrics take (y_true, y_pred).
        print("Evaluation MSE:", mean_squared_error(y_test, y_pred))
        print("Evaluation MAE:", mean_absolute_error(y_test, y_pred))

    # ----------------------- OTHER METHODS ------------------------------

    @staticmethod
    def parse(tree, item_cache):
        """Look up a previously defined sklearner by id.

        :raises ArgumentException: if ``tree`` does not reference a
                                   previously defined object
        """
        obj = parser.try_lookup_obj(tree, item_cache, BaseSKLearn, None)
        if obj is not None:
            return obj
        else:
            raise ArgumentException("The definition of the base sklearner can "
                                    "only be used for look up of any previously"
                                    " defined schema via 'sklearner(id=ID)'")


# Register BaseSKLearn under the key "sklearner" (the name used in
# 'sklearner(id=ID)' definitions handled by BaseSKLearn.parse).
main_register.append_register(
    BaseSKLearn,
    "sklearner",
)
