import warnings

import astropy.units as u
import numpy as np

from gala.potential.potential.builtin.core import (
    HernquistPotential,
    LogarithmicPotential,
    MiyamotoNagaiPotential,
    MN3ExponentialDiskPotential,
    NFWPotential,
    PowerLawCutoffPotential,
)
from gala.potential.potential.ccompositepotential import CCompositePotential
from gala.units import galactic
from gala.util import GalaFutureWarning

# Public API of this module: the predefined composite Milky Way mass models.
__all__ = [
    "BovyMWPotential2014",
    "LM10Potential",
    "MilkyWayPotential",
    "MilkyWayPotential2022",
]


class LM10Potential(CCompositePotential):
    """
    The Galactic potential used by Law and Majewski (2010) to represent
    the Milky Way as a three-component sum of disk, bulge, and halo.

    The disk potential is an axisymmetric
    :class:`~gala.potential.MiyamotoNagaiPotential`, the bulge potential
    is a spherical :class:`~gala.potential.HernquistPotential`, and the
    halo potential is a triaxial :class:`~gala.potential.LogarithmicPotential`.

    Default parameters are fixed to those found in LM10 by fitting N-body
    simulations to the Sagittarius stream.

    Parameters
    ----------
    units : `~gala.units.UnitSystem` (optional)
        Set of non-reducible units that specify (at minimum) the
        length, mass, time, and angle units.
    disk : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.MiyamotoNagaiPotential`.
        Any parameter not given falls back to the LM10 default. The dict
        passed in is not modified.
    bulge : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.HernquistPotential`.
        Any parameter not given falls back to the LM10 default. The dict
        passed in is not modified.
    halo : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.LogarithmicPotential`.
        Any parameter not given falls back to the LM10 default. The dict
        passed in is not modified.

    Note: in subclassing, order of arguments must match order of potential
    components added at bottom of init.
    """

    def __init__(self, units=galactic, disk=None, bulge=None, halo=None):
        default_disk = {"m": 1e11 * u.Msun, "a": 6.5 * u.kpc, "b": 0.26 * u.kpc}
        default_bulge = {"m": 3.4e10 * u.Msun, "c": 0.7 * u.kpc}
        default_halo = {
            "q1": 1.38,
            "q2": 1.0,
            "q3": 1.36,
            "r_h": 12.0 * u.kpc,
            "phi": 97 * u.degree,
            # NOTE(review): the sqrt(2) presumably converts the LM10 v_halo
            # value to the v_c convention of LogarithmicPotential -- confirm.
            "v_c": np.sqrt(2) * 121.858 * u.km / u.s,
        }

        # Build fresh parameter dicts so the caller's dicts are never mutated;
        # user-supplied values take precedence over the defaults.
        disk = {**default_disk, **(disk or {})}
        bulge = {**default_bulge, **(bulge or {})}
        halo = {**default_halo, **(halo or {})}

        super().__init__()

        self["disk"] = MiyamotoNagaiPotential(units=units, **disk)
        self["bulge"] = HernquistPotential(units=units, **bulge)
        self["halo"] = LogarithmicPotential(units=units, **halo)
        self.lock = True


# ============================================================================
# Gala MilkyWayPotential
#


def _setup_mwp_v1(obj, units, **kwargs):
    """
    Add the components of the version-1 MilkyWayPotential to ``obj``.

    Components are added in the order disk, bulge, nucleus, halo.

    Parameters
    ----------
    obj : `~gala.potential.CCompositePotential`
        The composite potential to add the components to.
    units : `~gala.units.UnitSystem`
        Unit system passed to every component.
    **kwargs
        Optional ``disk``, ``bulge``, ``nucleus``, and ``halo`` dicts of
        parameter overrides for the corresponding component. ``None`` is
        treated like an empty dict, and the dicts passed in are not modified.
        Any other keyword is ignored.
    """
    components = (
        (
            "disk",
            MiyamotoNagaiPotential,
            {"m": 6.8e10 * u.Msun, "a": 3.0 * u.kpc, "b": 0.28 * u.kpc},
        ),
        ("bulge", HernquistPotential, {"m": 5e9 * u.Msun, "c": 1.0 * u.kpc}),
        ("nucleus", HernquistPotential, {"m": 1.71e9 * u.Msun, "c": 0.07 * u.kpc}),
        ("halo", NFWPotential, {"m": 5.4e11 * u.Msun, "r_s": 15.62 * u.kpc}),
    )

    for name, potential_cls, defaults in components:
        # Merge into a new dict: user values override defaults, and the
        # caller's dict (if any) is left untouched.
        params = {**defaults, **(kwargs.get(name) or {})}
        obj[name] = potential_cls(units=units, **params)


def _setup_mwp_2022(obj, units, **kwargs):
    """
    Add the components of the 2022 (version 2) Milky Way model to ``obj``.

    Components are added in the order disk, bulge, nucleus, halo.

    Parameters
    ----------
    obj : `~gala.potential.CCompositePotential`
        The composite potential to add the components to.
    units : `~gala.units.UnitSystem`
        Unit system passed to every component.
    **kwargs
        Optional ``disk``, ``bulge``, ``nucleus``, and ``halo`` dicts of
        parameter overrides for the corresponding component. ``None`` is
        treated like an empty dict, and the dicts passed in are not modified.
        Any other keyword is ignored.
    """
    components = (
        (
            "disk",
            MN3ExponentialDiskPotential,
            {"m": 4.7717e10 * u.Msun, "h_R": 2.6 * u.kpc, "h_z": 0.3 * u.kpc},
        ),
        ("bulge", HernquistPotential, {"m": 5e9 * u.Msun, "c": 1.0 * u.kpc}),
        (
            "nucleus",
            HernquistPotential,
            {"m": 1.8142e9 * u.Msun, "c": 0.0688867 * u.kpc},
        ),
        ("halo", NFWPotential, {"m": 5.5427e11 * u.Msun, "r_s": 15.626 * u.kpc}),
    )

    for name, potential_cls, defaults in components:
        # Merge into a new dict: user values override defaults, and the
        # caller's dict (if any) is left untouched.
        params = {**defaults, **(kwargs.get(name) or {})}
        obj[name] = potential_cls(units=units, **params)


class MilkyWayPotential(CCompositePotential):
    """
    A mass-model for the Milky Way consisting of a spherical nucleus and bulge,
    a disk, and a spherical NFW dark matter halo.

    Two versions of the model are available, selected with ``version``:

    - ``"v1"``: the disk is a single Miyamoto-Nagai disk taken from `Bovy (2015)
      <https://ui.adsabs.harvard.edu/#abs/2015ApJS..216...29B/abstract>`_ - if
      you use this version, please also cite that work.
    - ``"v2"`` (alias ``"latest"``): the disk is a
      :class:`~gala.potential.MN3ExponentialDiskPotential`. This is the same
      model as :class:`~gala.potential.MilkyWayPotential2022`.

    Default parameters are fixed by fitting to a compilation of recent mass
    measurements of the Milky Way, from 10 pc to ~150 kpc.

    Parameters
    ----------
    version : str (optional)
        Which model to build; case-insensitive. Supported values are ``"v1"``,
        ``"v2"``, and ``"latest"``. If not specified, ``"v1"`` is used and a
        `~gala.util.GalaFutureWarning` is emitted. Any other value raises a
        `ValueError`.
    units : `~gala.units.UnitSystem` (optional)
        Set of non-reducible units that specify (at minimum) the
        length, mass, time, and angle units.
    disk : dict (optional)
        Parameters to be passed to the disk potential: a
        :class:`~gala.potential.MiyamotoNagaiPotential` for ``"v1"``, or a
        :class:`~gala.potential.MN3ExponentialDiskPotential` for ``"v2"``.
    bulge : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.HernquistPotential`.
    halo : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.NFWPotential`.
    nucleus : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.HernquistPotential`.

    Note: in subclassing, order of arguments must match order of potential
    components added at bottom of init.
    """

    # Stored alongside the components so the chosen model version is
    # serialized with the potential.
    _extra_serialize_args = ["version"]

    def __init__(self, version=None, units=galactic, **kwargs):
        super().__init__()

        # TODO: remove when MilkyWayPotential API changes
        if version is None:
            warnings.warn(
                "In a future version of Gala, the current MilkyWayPotential and "
                "MilkyWayPotential2022 classes will be combined into a single class, "
                "MilkyWayPotential, with an optional 'version' argument to select "
                "between the models. To use the old (version 1) MilkyWayPotential, "
                'specify version="v1" when creating an instance. To use the newer '
                '(version 2 = current MilkyWayPotential2022), specify version="v2".',
                GalaFutureWarning,
                # Attribute the warning to the caller's line, not this one.
                stacklevel=2,
            )
            version = "v1"

        self.version = str(version).lower()

        if self.version in ("latest", "v2"):
            _setup_mwp_2022(self, units, **kwargs)

        elif self.version == "v1":
            _setup_mwp_v1(self, units, **kwargs)

        else:
            raise ValueError(
                f"Invalid MilkyWayPotential version: {version}. Supported values are: "
                "(v1, v2, latest)"
            )

        self.lock = True


class MilkyWayPotential2022(CCompositePotential):
    """
    A mass-model for the Milky Way consisting of a spherical nucleus and bulge, a
    3-component sum of Miyamoto-Nagai disks to represent an exponential disk, and a
    spherical NFW dark matter halo.

    The radial dependence of the disk model is fit to the Eilers et al. 2019
    rotation curve, and its vertical structure is set by the shape of the
    phase-space spiral in the solar neighborhood (Darragh-Ford et al. 2023).

    Other parameters are fixed by fitting to a compilation of recent mass measurements
    of the Milky Way, from 10 pc to ~150 kpc.

    Constructing this class emits a `~gala.util.GalaFutureWarning`; the same
    model is available as ``MilkyWayPotential(version="v2")``.

    Parameters
    ----------
    units : `~gala.units.UnitSystem` (optional)
        Set of non-reducible units that specify (at minimum) the
        length, mass, time, and angle units.
    disk : dict (optional)
        Parameters to be passed to the
        :class:`~gala.potential.MN3ExponentialDiskPotential`.
    bulge : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.HernquistPotential`.
    halo : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.NFWPotential`.
    nucleus : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.HernquistPotential`.

    Note: in subclassing, order of arguments must match order of potential
    components added at bottom of init.
    """

    def __init__(self, units=galactic, disk=None, halo=None, bulge=None, nucleus=None):
        super().__init__()

        # TODO: remove when MilkyWayPotential API changes
        warnings.warn(
            "The MilkyWayPotential2022 class will be deprecated soon. Instead, use: "
            'MilkyWayPotential(version="v2") to get what is currently the '
            "MilkyWayPotential2022 class. Or, to always use the latest Milky Way model "
            'in Gala, specify MilkyWayPotential(version="latest")',
            GalaFutureWarning,
            # Attribute the warning to the caller's line, not this one.
            stacklevel=2,
        )

        # Shallow copies keep the caller's dicts unchanged when defaults are
        # filled in by the setup helper.
        disk = {} if disk is None else dict(disk)
        halo = {} if halo is None else dict(halo)
        bulge = {} if bulge is None else dict(bulge)
        nucleus = {} if nucleus is None else dict(nucleus)
        _setup_mwp_2022(self, units, disk=disk, halo=halo, bulge=bulge, nucleus=nucleus)

        self.lock = True


class BovyMWPotential2014(CCompositePotential):
    """
    An implementation of the ``MWPotential2014``
    `from galpy <https://galpy.readthedocs.io/en/latest/potential.html>`_
    and described in `Bovy (2015)
    <https://ui.adsabs.harvard.edu/#abs/2015ApJS..216...29B/abstract>`_.

    This potential consists of a spherical bulge and dark matter halo, and a
    Miyamoto-Nagai disk component.

    .. note::

        Because it internally uses the PowerLawCutoffPotential,
        this potential requires GSL to be installed, and Gala must have been
        built and installed with GSL support enabled (the default behavior).
        See http://gala.adrian.pw/en/latest/install.html for more information.

    Parameters
    ----------
    units : `~gala.units.UnitSystem` (optional)
        Set of non-reducible units that specify (at minimum) the
        length, mass, time, and angle units.
    disk : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.MiyamotoNagaiPotential`.
        Any parameter not given falls back to the default. The dict passed in
        is not modified.
    bulge : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.PowerLawCutoffPotential`.
        Any parameter not given falls back to the default. The dict passed in
        is not modified.
    halo : dict (optional)
        Parameters to be passed to the :class:`~gala.potential.NFWPotential`.
        Any parameter not given falls back to the default. The dict passed in
        is not modified.

    Note: in subclassing, order of arguments must match order of potential
    components added at bottom of init.
    """

    def __init__(self, units=galactic, disk=None, halo=None, bulge=None):
        default_disk = {
            "m": 68193902782.346756 * u.Msun,
            "a": 3.0 * u.kpc,
            "b": 280 * u.pc,
        }
        default_bulge = {
            "m": 4501365375.06545 * u.Msun,
            "alpha": 1.8,
            "r_c": 1.9 * u.kpc,
        }
        default_halo = {"m": 4.3683325e11 * u.Msun, "r_s": 16 * u.kpc}

        # Build fresh parameter dicts so the caller's dicts are never mutated;
        # user-supplied values take precedence over the defaults.
        disk = {**default_disk, **(disk or {})}
        bulge = {**default_bulge, **(bulge or {})}
        halo = {**default_halo, **(halo or {})}

        super().__init__()

        self["disk"] = MiyamotoNagaiPotential(units=units, **disk)
        self["bulge"] = PowerLawCutoffPotential(units=units, **bulge)
        self["halo"] = NFWPotential(units=units, **halo)
        self.lock = True


# --------------------------------------------------------------------
# class TriaxialMWPotential(CCompositePotential):

#     def __init__(self, units=galactic,
#                  disk=dict(), bulge=dict(), halo=dict()):
#         """ Axis ratio values taken from Jing & Suto (2002). Other
#             parameters come from a by-eye fit to Bovy's MWPotential2014.
#             Choice of v_c sets circular velocity at Sun to 220 km/s
#         """

#         default_disk = dict(m=7E10, a=3.5, b=0.14)
#         default_bulge = dict(m=1E10, c=1.1)
#         default_halo = dict(a=1., b=0.75, c=0.55,
#                             v_c=0.239225, r_s=30.,
#                             phi=0., theta=0., psi=0.)

#         for k, v in default_disk.items():
#             if k not in disk:
#                 disk[k] = v

#         for k, v in default_bulge.items():
#             if k not in bulge:
#                 bulge[k] = v

#         for k, v in default_halo.items():
#             if k not in halo:
#                 halo[k] = v

#         kwargs = dict()
#         kwargs["disk"] = MiyamotoNagaiPotential(units=units, **disk)
#         kwargs["bulge"] = HernquistPotential(units=units, **bulge)
#         kwargs["halo"] = LeeSutoTriaxialNFWPotential(units=units, **halo)
#         super(TriaxialMWPotential, self).__init__(**kwargs)
# --------------------------------------------------------------------
