from src.training import parser_tools as parset
from src.training import parser
from src.training import main_register

from .base_sklearn import BaseSKLearn

from sklearn import svm


class SVR(BaseSKLearn):
    """Support Vector Regression learner wrapping ``sklearn.svm.SVR``.

    Adds the ``kernel``, ``C`` and ``epsilon`` hyperparameters on top of the
    arguments inherited from :class:`BaseSKLearn` and passes them to
    ``sklearn.svm.SVR`` when the model is initialised.
    """

    # Argument specification for the project's parser, extending the
    # arguments inherited from BaseSKLearn. Each tuple presumably reads
    # (name, <flag>, default, type, help) — TODO confirm the meaning of the
    # second field against parser_tools.ClassArguments.
    # The C default is declared as the float 1.0 so it matches both its
    # declared type (float) and the ``C=1.0`` default of ``__init__``.
    # ``order`` lists the parameters in the same order as ``__init__``.
    arguments = parset.ClassArguments(
        'SVR', BaseSKLearn.arguments,
        ("kernel", True, "rbf", str, "Kernel to use for the SVR"),
        ("epsilon", True, 0.1, float, "Epsilon"),
        ("C", True, 1.0, float, "C"),
        order=["x_fields", "y_field", "kernel", "C", "epsilon", "load", "store",
               "learner_formats", "out", "variables", "id"])

    def __init__(self, x_fields, y_field, kernel="rbf", C=1.0, epsilon=0.1,
                 load=None, store=None, learner_formats=None, out=".",
                 variables=None, id=None):
        """Initialise the base learner and store the SVR hyperparameters.

        Args:
            x_fields: Forwarded unchanged to ``BaseSKLearn.__init__``.
            y_field: Forwarded unchanged to ``BaseSKLearn.__init__``.
            kernel: Kernel name passed to ``sklearn.svm.SVR``
                (default ``"rbf"``).
            C: Regularisation parameter passed to ``sklearn.svm.SVR``
                (default ``1.0``).
            epsilon: Width of the epsilon-insensitive tube passed to
                ``sklearn.svm.SVR`` (default ``0.1``).
            load, store, learner_formats, out, variables, id: Forwarded
                unchanged, in this order, to ``BaseSKLearn.__init__``.
        """
        BaseSKLearn.__init__(self, x_fields, y_field,
                             load, store, learner_formats, out, variables, id)
        # Only the hyperparameters are stored here; the sklearn estimator is
        # constructed from them in _initialize_model.
        self._kernel = kernel
        self._C = C
        self._epsilon = epsilon

    def _initialize_model(self, *args, **kwargs):
        """Build a fresh ``sklearn.svm.SVR`` and assign it to ``self._model``.

        The estimator uses the ``kernel``, ``epsilon`` and ``C`` values stored
        by ``__init__``. Any positional or keyword arguments are accepted but
        ignored (presumably to match the BaseSKLearn hook signature — TODO
        confirm).
        """
        self._model = svm.SVR(
            kernel=self._kernel,
            epsilon=self._epsilon,
            C=self._C)

    @staticmethod
    def parse(tree, item_cache):
        """Parse an ``SVR`` from ``tree`` using the generic object parser.

        Delegates to ``parser.try_whole_obj_parse_process`` with this class
        as the target type and returns its result unchanged (presumably an
        ``SVR`` instance built from the parsed arguments — confirm in
        ``parser``).
        """
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  SVR)


# Import-time side effect: register the SVR learner with main_register under
# the name "svr" (presumably the key used to select it by name — confirm in
# main_register).
main_register.append_register(SVR, "svr")
