import inspect

import numpy as np

from gala.units import DimensionlessUnitSystem, UnitSystem
from gala.util import atleast_2d


class PotentialParameter:
    """A class for defining parameters needed by the potential classes

    Parameters
    ----------
    name : str
        The name of the parameter. For example, "m" for mass.
    physical_type : str (optional)
        The physical type (as defined by `astropy.units`) of the expected
        physical units that this parameter is in. For example, "mass" for a mass
        parameter. Pass `None` if the parameter is not meant to be a Quantity (e.g.,
        string or integer values).
    default : numeric, str, array (optional)
        The default value of the parameter. `None` means the parameter has no
        default.
    equivalencies : `astropy.units.equivalencies.Equivalency` (optional)
        Any equivalencies required for the parameter.
    python_only : bool (optional)
        Controls whether to pass this parameter value to the C/Cython layer. True means
        a parameter is a Python-only value and will not be passed to the C/Cython layer.
        Default is False, meaning by default parameters will be passed to the C/Cython
        layer.
    ndim : int (optional)
        The expected number of array dimensions of the parameter value after
        conversion (0 for a scalar, 1 for a vector, ...). Must be non-negative.
        Default is 0.
    convert : callable or None (optional)
        A callable applied to the raw parameter value (including the default)
        before it is validated. Default is `numpy.asanyarray`. Pass `None` to
        keep the value exactly as given.

    Raises
    ------
    ValueError
        If ``ndim`` is negative.
    TypeError
        If ``convert`` is neither callable nor `None`.
    """

    def __init__(
        self,
        name,
        physical_type="dimensionless",
        default=None,
        equivalencies=None,
        python_only=False,
        ndim=0,
        convert=np.asanyarray,
    ):
        # TODO: could add a "shape" argument?

        self.name = str(name)
        self.physical_type = str(physical_type) if physical_type is not None else None
        self.default = default
        self.equivalencies = equivalencies
        self.python_only = bool(python_only)

        self.ndim = int(ndim)
        # An array can never have a negative number of dimensions, so such a
        # parameter could never pass validation; reject it up front.
        if self.ndim < 0:
            raise ValueError(
                f"Parameter {self.name} must have ndim >= 0, got ndim={self.ndim}"
            )

        # Fail early here rather than with an obscure "object is not callable"
        # error the first time a value is parsed.
        if convert is not None and not callable(convert):
            raise TypeError(
                f"Parameter {self.name}: convert must be callable or None, "
                f"got {convert!r}"
            )
        self.convert = convert

    def __repr__(self):
        if self.physical_type is None:
            return f"<PotentialParameter: {self.name}>"
        return f"<PotentialParameter: {self.name} [{self.physical_type}]>"


class CommonBase:
    """Shared machinery for the potential and frame classes.

    Subclasses declare their parameters as class attributes holding
    `PotentialParameter` instances. ``__init_subclass__`` collects them into
    ``cls._parameters`` and builds a matching ``__signature__``. The remaining
    methods parse and unit-convert parameter values, prepare input arrays for
    the C layer, and implement equality and string representations.
    """

    def __init_subclass__(cls, GSL_only=False, EXP_only=False, **kwargs):
        """Collect `PotentialParameter` attributes and build the init signature.

        Parameters
        ----------
        GSL_only : bool (optional)
            Stored on the subclass as ``_GSL_only``.
        EXP_only : bool (optional)
            Stored on the subclass as ``_EXP_only``.
        **kwargs
            May contain ``parameters``, a mapping of extra name ->
            `PotentialParameter` entries. These are considered in addition to
            the ones defined in the class body, and the class body wins on name
            clashes. The mapping is not modified. All other keyword arguments
            are forwarded to ``super().__init_subclass__``.
        """
        # Read the default call signature for the init
        sig = inspect.signature(cls.__init__)

        # Collect all potential parameters defined on the class:
        cls._parameters = {}
        sig_parameters = []

        # Also allow passing parameters in to subclassing. Copy the mapping so
        # that merging the class namespace into it below does not mutate a
        # dict owned by the caller (e.g., another class's ``_parameters``).
        subcls_params = dict(kwargs.pop("parameters", {}))
        subcls_params.update(cls.__dict__)

        for k, v in subcls_params.items():
            if not isinstance(v, PotentialParameter):
                continue

            cls._parameters[k] = v

            # A default of None means "no default" in the generated signature.
            default = inspect.Parameter.empty if v.default is None else v.default

            sig_parameters.append(
                inspect.Parameter(
                    k, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=default
                )
            )

        for k, param in sig.parameters.items():
            if k == "self" or param.kind == param.VAR_POSITIONAL:
                continue
            sig_parameters.append(param)
        # inspect.Signature requires parameters ordered by kind; sorted() is
        # stable, so the declaration order within each kind is preserved.
        sig_parameters = sorted(sig_parameters, key=lambda x: int(x.kind))

        # Define a new init signature based on the potential parameters:
        newsig = sig.replace(parameters=tuple(sig_parameters))
        cls.__signature__ = newsig

        super().__init_subclass__(**kwargs)

        cls._GSL_only = GSL_only
        cls._EXP_only = EXP_only

        if not hasattr(cls, "_extra_serialize_args"):
            cls._extra_serialize_args = []

    @classmethod
    def _validate_units(cls, units):
        """Coerce ``units`` into a `UnitSystem` instance.

        ``None`` becomes a `DimensionlessUnitSystem`, a string is parsed with
        `UnitSystem.from_string`, and an existing `UnitSystem` is returned
        unchanged. Any other value is unpacked as ``UnitSystem(*units)``.
        """
        if units is None:
            units = DimensionlessUnitSystem()

        elif isinstance(units, str):
            units = UnitSystem.from_string(units)

        elif not isinstance(units, UnitSystem):
            units = UnitSystem(*units)

        return units

    def _parse_parameter_values(self, *args, strict=True, **kwargs):
        """Match positional/keyword arguments to the declared parameters.

        Positional arguments are assigned to parameters in declaration order.
        Remaining parameters are taken from ``kwargs`` or fall back to their
        defaults. Each value is then passed through the parameter's
        ``convert`` callable, if it has one.

        Parameters
        ----------
        *args
            Parameter values, in declaration order.
        strict : bool (optional)
            If True, leftover (unrecognized) keyword arguments raise an error.
            If False, they are ignored.
        **kwargs
            Parameter values by name.

        Returns
        -------
        parameter_values : dict
            Mapping of parameter name to (converted) value.
        parameter_is_default : set
            Names of the parameters that were filled from their defaults.

        Raises
        ------
        ValueError
            If there are too many positional arguments, if unexpected keyword
            arguments are given in strict mode, or if a parameter with a
            physical type has the wrong number of dimensions.
        """
        expected_parameter_keys = list(self._parameters.keys())
        n_args = len(args)

        if n_args > len(expected_parameter_keys):
            raise ValueError(
                "Too many positional arguments passed in to "
                f"{self.__class__.__name__}: Potential and Frame classes only "
                "accept parameters as positional arguments, all other "
                "arguments (e.g., units) must now be passed in as keyword "
                "argument."
            )

        # Get any parameters passed as positional arguments
        parameter_values = dict(zip(expected_parameter_keys, args))
        parameter_is_default = set()

        # Get parameters passed in as keyword arguments:
        for k in expected_parameter_keys[n_args:]:
            if k in kwargs:
                val = kwargs.pop(k)
            else:
                val = self._parameters[k].default
                parameter_is_default.add(k)
            parameter_values[k] = val

        for k, val in parameter_values.items():
            convert = self._parameters[k].convert
            if convert is not None:
                parameter_values[k] = convert(val)

        if kwargs and strict:
            raise ValueError(
                f"{self.__class__} received unexpected keyword "
                f"argument(s): {list(kwargs.keys())}"
            )

        for k, pval in parameter_values.items():
            pp = self._parameters[k]
            # Only parameters with a physical type (i.e., numeric values) are
            # shape-checked; others may be arbitrary Python objects.
            if pp.physical_type is not None and pval.ndim != pp.ndim:
                raise ValueError(
                    f"Parameter {k} should have ndim={pp.ndim} "
                    f"dimensions, but has ndim={pval.ndim}"
                )

        return parameter_values, parameter_is_default

    def _prepare_parameters(self, parameters, units):
        """Attach and validate units on each parameter value.

        Parameters
        ----------
        parameters : dict
            Mapping of parameter name to value (with or without units).
        units : `UnitSystem`
            The unit system to express the parameters in.

        Returns
        -------
        pars : dict
            Mapping of parameter name to the value decomposed into ``units``.
            Values of parameters with no physical type are passed through
            unchanged.

        Raises
        ------
        ValueError
            If a value carries a unit that is not equivalent to the expected
            physical type (checked only for non-dimensionless unit systems).
        """
        pars = {}
        for k, v in parameters.items():
            expected_ptype = self._parameters[k].physical_type
            expected_unit = (
                units[expected_ptype] if expected_ptype is not None else None
            )
            equiv = self._parameters[k].equivalencies

            if hasattr(v, "unit"):
                if not isinstance(
                    units, DimensionlessUnitSystem
                ) and not v.unit.is_equivalent(expected_unit, equiv):
                    msg = (
                        f"Parameter {k} has physical type "
                        f"'{v.unit.physical_type}', but we expected a "
                        f"physical type '{expected_ptype}'"
                    )
                    if equiv is not None:
                        msg = (
                            msg + f" or something equivalent via the {equiv} "
                            "equivalency."
                        )

                    raise ValueError(msg)

                # NOTE: this can lead to some comparison issues in __eq__, which
                # tests for strong equality between parameter values. Here, the
                # .to() could cause small rounding issues in comparisons
                if v.unit.physical_type != expected_ptype:
                    v = v.to(expected_unit, equiv)

                v = v.decompose(units)

            elif expected_ptype is not None:
                # A bare (unitless) value for a parameter that has a physical
                # type: interpret it in the unit system's unit for that type.
                # Parameters with physical_type=None skip this and are stored
                # as-is.

                # TODO: remove when fix potentials that ask for scale velocity!
                if expected_ptype == "speed":
                    v = v * units["length"] / units["time"]
                else:
                    v = v * units[expected_ptype]

                v = v.decompose(units)

            pars[k] = v

        return pars

    def _remove_units_prepare_shape(self, x):
        """Strip units from ``x`` and return it as a float64 array.

        Quantities are decomposed into ``self.units``. `PhaseSpacePosition`
        objects are converted with ``x.w(self.units)``. The result is passed
        through gala's ``atleast_2d`` (with ``insert_axis=1``) and cast to
        float64.
        """
        from gala.dynamics import PhaseSpacePosition

        if hasattr(x, "unit"):
            x = x.decompose(self.units).value

        elif isinstance(x, PhaseSpacePosition):
            x = x.w(self.units)

        return atleast_2d(x, insert_axis=1).astype(np.float64)

    def _get_c_valid_arr(self, x, transpose=True):
        """
        Prepare an array for passing to C: make sure it's 2D and contiguous.

        Parameters
        ----------
        x : array-like
            The input array.
        transpose : bool (optional)
            If True, transpose the array so that shape is (N, ndim). Default is True.

        Returns
        -------
        orig_shape : tuple
            The original shape of the input array.
        x : ndarray
            The reshaped, contiguous array.
        """
        orig_shape = x.shape
        x = x.reshape(orig_shape[0], -1)  # 2D
        if transpose:
            x = x.T
        x = np.ascontiguousarray(x)
        return orig_shape, x

    def _validate_prepare_time(self, t, N_pos):
        """
        Make sure that t is a 1D array and compatible with the C position array.

        Parameters
        ----------
        t : numeric, array-like, or Quantity
            A single time, or one time per position. Quantities are decomposed
            into ``self.units``.
        N_pos : int
            The number of positions the times must match.

        Returns
        -------
        t : ndarray
            A flattened, C-contiguous 1D array of times.

        Raises
        ------
        ValueError
            If more than one time is given and the count differs from ``N_pos``.
        """
        if hasattr(t, "unit"):
            t = t.decompose(self.units).value

        # np.atleast_1d accepts scalars, lists/tuples and arrays alike. Plain
        # Python sequences must go through it too, since they have no .ravel().
        t = np.ascontiguousarray(np.atleast_1d(t).ravel())

        if len(t) > 1 and len(t) != N_pos:
            raise ValueError(
                "If passing in an array of times, it must have a shape "
                "compatible with the input position(s)."
            )

        return t

    # For comparison operations
    def __eq__(self, other):
        """Compare parameter names/values, class name, and unit system."""
        if other is None or not hasattr(other, "parameters"):
            return False

        # Objects with a different number of parameters can never be equal.
        # Without this check, zip() below would silently ignore the extras.
        if len(self.parameters) != len(other.parameters):
            return False

        # the funkiness in the below is in case there are array parameters:
        par_bool = [
            (k1 == k2) and np.all(self.parameters[k1] == other.parameters[k2])
            for k1, k2 in zip(self.parameters.keys(), other.parameters.keys())
        ]
        return (
            np.all(par_bool)
            and (str(self) == str(other))
            and (self.units == other.units)
        )

    # String representations:
    def __repr__(self):
        """Show the class name, each parameter (with unit), and the units.

        Floats are printed with two decimals, or in scientific notation when
        their magnitude is below 1e-2 or above 1e5. Ints are printed in
        scientific notation when their magnitude exceeds 1e5.
        """
        pars = []

        keys = self.parameters.keys()
        for k in keys:
            v = self.parameters[k]
            post = ""

            if hasattr(v, "unit"):
                post = f" {v.unit}"
                v = v.value

            if isinstance(v, float):
                if v == 0:
                    par = f"{v:.0f}"
                elif np.log10(np.abs(v)) < -2 or np.log10(np.abs(v)) > 5:
                    par = f"{v:.2e}"
                else:
                    par = f"{v:.2f}"

            # Compare magnitudes directly rather than via np.log10. log10(0)
            # emits a divide-by-zero warning, and numpy cannot take the log of
            # Python ints too large for int64.
            elif isinstance(v, int) and abs(v) > 10**5:
                par = f"{v:.2e}"

            else:
                par = str(v)

            pars.append(f"{k}={par}{post}")

        par_str = ", ".join(pars)

        if isinstance(self.units, DimensionlessUnitSystem):
            return f"<{self.__class__.__name__}: {par_str} (dimensionless)>"
        core_units_str = ",".join(map(str, self.units._core_units))
        return f"<{self.__class__.__name__}: {par_str} ({core_units_str})>"

    def __str__(self):
        return self.__class__.__name__
