from src.training import parser_tools as parset
from src.training import parser
from src.training import main_register

from .base_sklearn import BaseSKLearn

from sklearn import ensemble


class RandomForestRegressor(BaseSKLearn):
    """Learner backed by a scikit-learn random forest regressor.

    Thin wrapper around ``sklearn.ensemble.RandomForestRegressor``: the
    model object is built in ``_initialize_model``, everything else
    (fields, loading/storing, output) is forwarded to ``BaseSKLearn``.
    Only ``n_estimators`` is exposed; every other forest setting keeps
    scikit-learn's default.
    """

    # Single source of truth for the default tree count, shared by the
    # argument specification and the constructor signature below.
    _DEFAULT_N_ESTIMATORS = 100

    # Argument specification handed to the project's parser (``parse`` passes
    # this class to ``parser.try_whole_obj_parse_process``).
    # NOTE(review): the tuple layout looks like
    # (name, optional?, default, type, help) — confirm against parser_tools.
    # NOTE(review): the group name is spelled "RandomforestRegressor"
    # (lower-case "f"), unlike the class name; kept unchanged because
    # configurations or lookups may depend on the exact string.
    # ``order`` mirrors the positional parameter order of ``__init__``.
    arguments = parset.ClassArguments(
        'RandomforestRegressor', BaseSKLearn.arguments,
        ("n_estimators", True, _DEFAULT_N_ESTIMATORS, int, "Number of trees"),
        order=["x_fields", "y_field", "n_estimators", "load", "store",
               "learner_formats", "out", "variables", "id"])

    def __init__(self, x_fields, y_field, n_estimators=_DEFAULT_N_ESTIMATORS,
                 load=None, store=None, learner_formats=None, out=".",
                 variables=None, id=None):
        """Configure the learner.

        :param n_estimators: number of trees in the forest; passed to
            ``sklearn.ensemble.RandomForestRegressor`` when the model is
            (re)initialised.

        All other parameters (``x_fields``, ``y_field``, ``load``, ``store``,
        ``learner_formats``, ``out``, ``variables``, ``id``) are forwarded
        unchanged and in this order to ``BaseSKLearn.__init__``. ``id`` keeps
        its name (shadowing the builtin) because it is part of the parser's
        argument order above.
        """
        # Set the hyper-parameter *before* delegating to the base
        # constructor so it is already available if the base initialisation
        # ever invokes the ``_initialize_model`` hook.
        self._n_estimators = n_estimators
        super().__init__(x_fields, y_field,
                         load, store, learner_formats, out, variables, id)

    def _initialize_model(self, *args, **kwargs):
        """Create a fresh scikit-learn forest and store it in ``self._model``.

        Extra positional/keyword arguments are accepted but ignored
        (presumably to match the base-class hook signature — confirm).
        """
        self._model = ensemble.RandomForestRegressor(
            n_estimators=self._n_estimators,
        )

    @staticmethod
    def parse(tree, item_cache):
        """Build an instance from a parsed configuration tree.

        Delegates to ``parser.try_whole_obj_parse_process`` with this class;
        the meaning of ``tree`` and ``item_cache`` is defined by that helper.
        """
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  RandomForestRegressor)


# Register the learner at import time under the key "random_forest_regressor"
# so the project's main registry can look it up by that name.
main_register.append_register(RandomForestRegressor, "random_forest_regressor")
