"""Automated tests based on the skbase test suite template."""
import numbers
import types
from copy import deepcopy
from inspect import getfullargspec, isclass, signature

import joblib
import numpy as np
import pandas as pd
from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator
from skbase.testing import QuickTester as _QuickTester
from skbase.testing import TestAllObjects as _TestAllObjects
from skbase.testing.utils.inspect import _get_args

from skpro.registry import OBJECT_TAG_LIST, all_objects
from skpro.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
from skpro.tests.scenarios.scenarios_getter import retrieve_scenarios
from skpro.tests.test_switch import run_test_for_class
from skpro.utils._doctest import run_doctest
from skpro.utils.deep_equals import deep_equals
from skpro.utils.random_state import set_random_state

# whether to test only estimators from modules that are changed w.r.t. main
# default is False, can be set to True by pytest --only_changed_modules True flag
# NOTE(review): not read anywhere in the visible part of this file; presumably
#   set from a pytest hook and consumed via run_test_for_class - confirm
ONLY_CHANGED_MODULES = False


class PackageConfig:
    """Contains package config variables for test classes.

    Used as a mix-in: the test classes in this module list it as their first
    base class, so the values below are available as class attributes on them.
    """

    # class variables which can be overridden by descendants
    # ------------------------------------------------------

    # package to search for objects
    # expected type: str, package/module name, relative to python environment root
    package_name = "skpro"

    # object types (class names) to exclude from testing entirely
    # expected type: list of str, str are class names
    exclude_objects = EXCLUDE_ESTIMATORS

    # individual tests to exclude, per estimator
    # expected type: dict of lists, key:str, value: List[str]
    # keys are class names of estimators, values are lists of test names to exclude
    excluded_tests = EXCLUDED_TESTS

    # list of valid tags
    # expected type: list of str, str are tag names
    valid_tags = OBJECT_TAG_LIST


class BaseFixtureGenerator(_BaseFixtureGenerator):
    """Fixture generator for base testing functionality in skpro.

    Test classes inheriting from this and not overriding pytest_generate_tests
        will have estimator and scenario fixtures parametrized out of the box.

    Descendants can override:
        object_type_filter: class variable; None, a class, or str / list of str
            str values are matched against the object_type tag, see _all_objects
            determines which estimators are being retrieved and tested
        fixture_sequence: list of str
            sequence of fixture variable names in conditional fixture generation
        _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list
            generating list of fixtures for fixture variable with name [variable]
                to be used in test with name test_name
            can optionally use values for fixtures earlier in fixture_sequence,
                these must be input as kwargs in a call
        is_excluded: static method (test_name: str, est: class) -> bool
            whether test with name test_name should be excluded for estimator est
                should be used only for encoding general rules, not individual skips
                individual skips should go on the EXCLUDED_TESTS list in _config
            requires _generate_object_class and _generate_object_instance as is
        _excluded_scenario: static method (test_name: str, scenario) -> bool
            whether scenario should be skipped in test with test_name test_name
            requires _generate_scenario as is

    Fixtures parametrized
    ---------------------
    object_class: estimator inheriting from BaseObject
        ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
    object_instance: instance of estimator inheriting from BaseObject
        ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
        instances are generated by create_test_instance class method of object_class
    scenario: instance of TestScenario
        ranges over all scenarios returned by retrieve_scenarios
        applicable for object_class or object_instance
    """

    # overrides object retrieval in scikit-base
    def _all_objects(self):
        """Collect the object classes that the tests of this class range over.

        The selection is controlled by ``self.object_type_filter``:

        * if the attribute is missing or None, retrieve all objects
        * if it is a class, retrieve all classes inheriting from it; the lookup
          is narrowed by the ``object_type`` class tag of that class
        * otherwise (assumed str or list of str), retrieve all classes with tag
          ``object_type`` in ``self.object_type_filter``

        Classes listed in ``self.exclude_objects`` are left out of the lookup.

        Returns
        -------
        list of classes, restricted to those accepted by ``run_test_for_class``
        """
        type_filter = getattr(self, "object_type_filter", None)
        filter_is_class = isclass(type_filter)

        # a class filter is translated into its object_type tag for the lookup
        lookup_types = (
            type_filter.get_class_tag("object_type", None)
            if filter_is_class
            else type_filter
        )

        candidates = all_objects(
            object_types=lookup_types,
            return_names=False,
            exclude_objects=self.exclude_objects,
        )

        if filter_is_class:
            candidates = [cls for cls in candidates if issubclass(cls, type_filter)]

        # run_test_for_class selects the estimators to run
        # based on whether they have changed, and whether they have all dependencies
        # internally, uses the ONLY_CHANGED_MODULES flag,
        # and checks the python env against python_dependencies tag
        return [cls for cls in candidates if run_test_for_class(cls)]

    # order in which the conditional fixtures are generated;
    # generators of later fixtures can use values of earlier ones, passed as kwargs
    fixture_sequence = [
        "object_class",
        "object_instance",
        "scenario",
    ]

    def _generate_scenario(self, test_name, **kwargs):
        """Return estimator test scenario.

        Fixtures parametrized
        ---------------------
        scenario: instance of TestScenario
            ranges over all scenarios returned by retrieve_scenarios
        """
        if "object_class" in kwargs.keys():
            obj = kwargs["object_class"]
        elif "object_instance" in kwargs.keys():
            obj = kwargs["object_instance"]
        else:
            return []

        scenarios = retrieve_scenarios(obj)
        scenarios = [s for s in scenarios if not self._excluded_scenario(test_name, s)]
        scenario_names = [type(scen).__name__ for scen in scenarios]

        return scenarios, scenario_names

    @staticmethod
    def _excluded_scenario(test_name, scenario):
        """Skip list generator for scenarios to skip in test_name.

        Arguments
        ---------
        test_name : str, name of test
        scenario : instance of TestScenario, to be used in test

        Returns
        -------
        bool, whether scenario should be skipped in test_name
        """
        # this line excludes all scenarios that do not have "is_enabled" flag
        #   we should slowly enable more scenarios for better coverage
        # comment out to run the full test suite with new scenarios
        if not scenario.get_tag("is_enabled", False, raise_error=False):
            return True

        return False


class TestAllObjects(PackageConfig, BaseFixtureGenerator, _TestAllObjects):
    """Generic tests for all objects in the mini package."""

    def test_doctest_examples(self, object_class):
        """Run doctests for the estimator class via run_doctest."""
        display_name = f"class {object_class.__name__}"
        run_doctest(object_class, name=display_name)

    # override this due to reserved_params index, columns, in the BaseDistribution class
    # index and columns params behave like pandas, i.e., are changed after __init__
    def test_constructor(self, object_class):
        """Check that the constructor has sklearn compatible signature and behaviour.

        Based on sklearn check_estimator testing of __init__ logic.
        Uses create_test_instance to create an instance.
        Assumes test_create_test_instance has passed and certified create_test_instance.

        Tests that:
        * constructor has no variable keyword arguments (``**kwargs``)
        * tests that constructor constructs an instance of the class
        * tests that all parameters are set in init to an attribute of the same name
        * tests that every parameter that is not set by the first parameter set of
          get_test_params has a default value
        * tests that such defaults are either one of the types np.float64, np.int64,
          or have exactly one of the following types:
            None, str, int, float, bool, tuple, function, np.float64
            (other type parameters should be None, default handling should be by writing
            the default to attribute of a different name, e.g., my_param_ not my_param)
        * tests that, for such parameters, get_params returns the default value;
          parameters named in the reserved_params class tag are exempt from this
        """
        msg = "constructor __init__ should have no variable keyword arguments"
        assert getfullargspec(object_class.__init__).varkw is None, msg

        estimator = object_class.create_test_instance()
        assert isinstance(estimator, object_class)
        est_name = estimator.__class__.__name__

        # Ensure that each parameter is set in init
        init_arg_names = _get_args(type(estimator).__init__)
        invalid_attr = set(init_arg_names) - set(vars(estimator)) - {"self"}
        assert not invalid_attr, (
            f"Estimator {est_name} should store all parameters"
            " as an attribute during init. Did not find "
            f"attributes `{sorted(invalid_attr)}`."
        )

        # Ensure that init does nothing but set parameters
        # No logic/interaction with other parameters
        def param_filter(p):
            """Identify hyper parameters of an estimator."""
            return p.name != "self" and p.kind not in [p.VAR_KEYWORD, p.VAR_POSITIONAL]

        init_params = [
            p
            for p in signature(estimator.__init__).parameters.values()
            if param_filter(p)
        ]

        params = estimator.get_params()

        # parameters set via the first test parameter set are not at their default,
        # so they are left out of the default value checks below
        test_params = object_class.get_test_params()
        if isinstance(test_params, list):
            test_params = test_params[0]
        test_param_names = test_params.keys()

        init_params = [p for p in init_params if p.name not in test_param_names]

        # loop invariants: tag lookup and admissible defaults
        reserved_params = object_class.get_class_tag("reserved_params", [])
        allowed_type_defaults = [np.float64, np.int64]
        allowed_default_types = [
            str,
            int,
            float,
            bool,
            tuple,
            type(None),
            np.float64,
            types.FunctionType,
        ]

        for param in init_params:
            # identity check against the sentinel, as == / != on an array-like
            # default would not yield a plain bool
            assert param.default is not param.empty, (
                f"parameter `{param.name}` for {est_name} has no default value "
                "and is not set in `get_test_params`"
            )
            if type(param.default) is type:
                assert param.default in allowed_type_defaults, (
                    f"default of parameter `{param.name}` for {est_name} is a type, "
                    "but not one of np.float64, np.int64"
                )
            else:
                assert type(param.default) in allowed_default_types, (
                    f"default of parameter `{param.name}` for {est_name} has "
                    f"disallowed type {type(param.default).__name__}"
                )

            # reserved params may change after __init__, so their value is not checked
            if param.name in reserved_params:
                continue

            param_value = params[param.name]
            if isinstance(param_value, np.ndarray):
                np.testing.assert_array_equal(param_value, param.default)
            elif bool(isinstance(param_value, numbers.Real) and np.isnan(param_value)):
                # Allows to set default parameters to np.nan
                assert param_value is param.default, param.name
            else:
                assert param_value == param.default, param.name

    # same here, reserved_params need to be dealt with
    def test_set_params_sklearn(self, object_class):
        """Check that set_params works correctly, mirrors sklearn check_set_params.

        Instead of the "fuzz values" in sklearn's check_set_params,
        the test parameter settings of the class are used (assumed valid),
        which guarantees settings that play along with the __init__ content.

        Parameters named in the reserved_params class tag are left out
        of the comparison between get_params and what was passed to set_params.
        """
        from skpro.utils.deep_equals import deep_equals

        cls_name = object_class.__name__
        estimator = object_class.create_test_instance()

        param_sets = object_class.get_test_params()
        if not isinstance(param_sets, list):
            param_sets = [param_sets]

        reserved = object_class.get_class_tag("reserved_params", tag_value_default=[])

        def _drop_reserved(mapping):
            """Return copy of mapping without the reserved parameter names."""
            return {key: val for key, val in mapping.items() if key not in reserved}

        not_self_msg = f"set_params of {cls_name} does not return self"

        for param_set in param_sets:
            # a test parameter set may contain only deviations from the defaults,
            # so it is merged into the defaults to obtain the full parameter set;
            # this also resets parameters of earlier iterations to their defaults
            full_set = object_class.get_param_defaults()
            full_set.update(param_set)

            returned = estimator.set_params(**full_set)
            assert returned is estimator, not_self_msg

            current = estimator.get_params(deep=False)
            is_equal, equals_msg = deep_equals(
                _drop_reserved(current), _drop_reserved(full_set), return_msg=True
            )
            assert is_equal, (
                f"get_params result of {cls_name} (x) does not match "
                f"what was passed to set_params (y). "
                f"Reason for discrepancy: {equals_msg}"
            )

    def test_get_test_params_coverage(self, object_class):
        """Check that get_test_params has good test coverage.

        Checks that:

        * get_test_params returns at least two test parameter sets
        """
        param_list = object_class.get_test_params()

        if isinstance(param_list, dict):
            param_list = [param_list]

        def _coerce_to_list_of_str(obj):
            if isinstance(obj, str):
                return obj
            elif isinstance(obj, list):
                return obj
            else:
                return []

        reserved_param_names = object_class.get_class_tag(
            "reserved_params", tag_value_default=None
        )
        reserved_param_names = _coerce_to_list_of_str(reserved_param_names)
        reserved_set = set(reserved_param_names)

        param_names = object_class.get_param_names()
        unreserved_param_names = set(param_names).difference(reserved_set)

        if len(unreserved_param_names) > 0:
            msg = (
                f"{object_class.__name__}.get_test_params should return "
                f"at least two test parameter sets, but only {len(param_list)} found."
            )
            assert len(param_list) > 1, msg


class TestAllEstimators(PackageConfig, BaseFixtureGenerator, _QuickTester):
    """Package level tests for all skpro estimators, i.e., objects with fit."""

    def test_fit_updates_state(self, object_instance, scenario):
        """Check fit/update state change."""
        # Check that fit updates the is-fitted states
        attrs = ["_is_fitted", "is_fitted"]

        estimator = object_instance
        object_class = type(object_instance)

        msg = (
            f"{object_class.__name__}.__init__ should call "
            f"super({object_class.__name__}, self).__init__, "
            "but that does not seem to be the case. Please ensure to call the "
            f"parent class's constructor in {object_class.__name__}.__init__"
        )
        assert hasattr(estimator, "_is_fitted"), msg

        # Check is_fitted attribute is set correctly to False before fit, at init
        for attr in attrs:
            assert not getattr(
                estimator, attr
            ), f"Estimator: {estimator} does not initiate attribute: {attr} to False"

        fitted_estimator = scenario.run(object_instance, method_sequence=["fit"])

        # Check is_fitted attributes are updated correctly to True after calling fit
        for attr in attrs:
            assert getattr(
                fitted_estimator, attr
            ), f"Estimator: {estimator} does not update attribute: {attr} during fit"

    def test_fit_returns_self(self, object_instance, scenario):
        """Check that fit returns self."""
        fit_return = scenario.run(object_instance, method_sequence=["fit"])
        assert (
            fit_return is object_instance
        ), f"Estimator: {object_instance} does not return self when calling fit"

    def test_fit_does_not_overwrite_hyper_params(self, object_instance, scenario):
        """Check that fit leaves the hyper-parameters of the estimator unchanged.

        Parameters are snapshotted via get_params and deepcopy before fit,
        and compared one by one with get_params of the fitted estimator.
        """
        set_random_state(object_instance)

        # physical copy of the parameters, taken before fitting
        params_before = deepcopy(object_instance.get_params())

        fitted = scenario.run(object_instance, method_sequence=["fit"])
        params_after = fitted.get_params()
        est_name = object_instance.__class__.__name__

        for name, before in params_before.items():
            after = params_after[name]

            # We should never change or mutate the internal state of input
            # parameters by default. To check this we use the joblib.hash function
            # that introspects recursively any subobjects to compute a checksum.
            # The only exception to this rule of immutable constructor parameters
            # is possible RandomState instance but in this check we explicitly
            # fixed the random_state params recursively to be integer seeds.
            msg = (
                f"Estimator {est_name!s} should not change or mutate "
                f" the parameter {name!s} from {before!s} to {after!s} during fit."
            )

            # joblib.hash has problems with pandas objects, so we use deep_equals then
            if isinstance(before, (pd.DataFrame, pd.Series)):
                unchanged = deep_equals(after, before)
            else:
                unchanged = joblib.hash(after) == joblib.hash(before)
            assert unchanged, msg
