from src.training import parser_tools as parset
from src.training import parser
from src.training import main_register

from .base_sklearn import BaseSKLearn

from sklearn import linear_model
from sklearn.preprocessing import PolynomialFeatures


class LinearRegression(BaseSKLearn):
    """Learner wrapping sklearn's ordinary least-squares ``LinearRegression``.

    If a ``degree`` other than ``None`` or ``1`` is configured, the input
    features are expanded with sklearn's ``PolynomialFeatures`` before being
    handed to the model. This lets the linear model fit a polynomial in the
    original features.
    """

    # Arguments accepted by this learner: everything BaseSKLearn declares plus
    # "degree". ``order`` fixes the positional order, which mirrors the
    # parameter order of __init__ below.
    arguments = parset.ClassArguments(
        'LinearRegression', BaseSKLearn.arguments,
        ("degree", True, None, int, "Applies sklearn's PolynomialFeatures with "
                                    "the given degree"),
        order=["x_fields", "y_field", "degree", "load", "store",
               "learner_formats",
               "out", "variables", "id"])

    def __init__(self, x_fields, y_field, degree=None, load=None, store=None,
                 learner_formats=None, out=".",
                 variables=None, id=None):
        """
        All arguments except ``degree`` are forwarded unchanged to
        ``BaseSKLearn.__init__``.

        :param x_fields: forwarded to BaseSKLearn.
        :param y_field: forwarded to BaseSKLearn.
        :param degree: degree passed to ``PolynomialFeatures`` when converting
                       input data; ``None`` or ``1`` disables the expansion.
        :param load: forwarded to BaseSKLearn.
        :param store: forwarded to BaseSKLearn.
        :param learner_formats: forwarded to BaseSKLearn.
        :param out: forwarded to BaseSKLearn (default ".").
        :param variables: forwarded to BaseSKLearn.
        :param id: forwarded to BaseSKLearn.
        """
        BaseSKLearn.__init__(self, x_fields, y_field,
                             load, store, learner_formats, out, variables, id)
        self._degree = degree

    def _initialize_model(self, *args, **kwargs):
        """Create a fresh, unfitted sklearn ``LinearRegression`` model.

        Extra positional/keyword arguments are accepted for interface
        compatibility but ignored.
        """
        # copy_X=False lets sklearn overwrite the training matrix instead of
        # copying it, avoiding an extra copy of the (possibly expanded) data.
        self._model = linear_model.LinearRegression(copy_X=False)

    # ----------------------- DATA PARSING METHODS -----------------------
    def _convert_data(self, *datas):
        """
        Convert raw data into the ``(x, y)`` arrays used by the model.

        The base conversion from ``BaseSKLearn._convert_data`` is applied
        first. If a polynomial degree other than ``None``/``1`` was configured,
        the feature matrix is then expanded with ``PolynomialFeatures``.

        :param datas: data objects forwarded unchanged to
                      ``BaseSKLearn._convert_data``.
        :return: tuple ``(x_train, y_train)``; ``x_train`` is the polynomial
                 expansion of the base features when a degree is configured.
        """
        x_train, y_train = BaseSKLearn._convert_data(self, *datas)
        if self._degree is not None and self._degree != 1:
            # PolynomialFeatures only learns the input feature count when it
            # is fitted, so re-fitting on every call produces the same
            # expansion for training and prediction data. The default
            # include_bias=True adds a constant column. It is kept so the
            # feature layout stays compatible with previously stored models.
            transformer = PolynomialFeatures(degree=self._degree)
            x_train = transformer.fit_transform(x_train)
        return x_train, y_train

    @staticmethod
    def parse(tree, item_cache):
        """Build a LinearRegression from a parsed description.

        Delegates to ``parser.try_whole_obj_parse_process`` with this class.
        """
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  LinearRegression)


# Make this learner available to the training front-end under the key
# "linear_regression".
main_register.append_register(
    LinearRegression,
    "linear_regression",
)
