import pickle

import astropy.units as u
import numpy as np
import pytest

from gala.potential import (
    ConstantRotatingFrame,
    Hamiltonian,
    KeplerPotential,
    StaticFrame,
)
from gala.units import galactic, solarsystem


def test_init():
    """Check Hamiltonian construction, copy-construction, repr, and rejection
    of incompatible potential/frame combinations."""
    kepler = KeplerPotential(m=1.0)
    static = StaticFrame()
    hamiltonian = Hamiltonian(potential=kepler, frame=static)

    # Passing an existing Hamiltonian should reuse the very same potential.
    copied = Hamiltonian(hamiltonian)
    assert copied.potential is hamiltonian.potential

    # The repr should name both the potential class and the frame class.
    text = repr(hamiltonian)
    for cls_name in ("KeplerPotential", "StaticFrame"):
        assert cls_name in text

    # Potential and frame sharing the solarsystem unit system are accepted,
    # as is omitting the frame altogether.
    ss_kepler = KeplerPotential(m=1.0, units=solarsystem)
    ss_static = StaticFrame(units=solarsystem)
    hamiltonian = Hamiltonian(potential=ss_kepler, frame=ss_static)
    hamiltonian = Hamiltonian(potential=ss_kepler)

    # Each (potential, frame) pair below must be rejected with ValueError.
    incompatible = []
    # Potential built without units; frame given galactic units.
    incompatible.append((KeplerPotential(m=1.0), StaticFrame(galactic)))
    # Potential in solarsystem units; frame built without units.
    incompatible.append(
        (KeplerPotential(m=1.0, units=solarsystem), StaticFrame())
    )
    # Same unit system on both sides, but a rotating frame with a scalar
    # Omega (presumably a dimensionality mismatch with the 3D potential —
    # confirm against gala's Hamiltonian validation).
    incompatible.append(
        (
            KeplerPotential(m=1.0, units=solarsystem),
            ConstantRotatingFrame(Omega=1.0 / u.yr, units=solarsystem),
        )
    )

    for bad_potential, bad_frame in incompatible:
        with pytest.raises(ValueError):
            hamiltonian = Hamiltonian(potential=bad_potential, frame=bad_frame)


def test_pickle(tmpdir):
    """Round-trip Hamiltonians (static and rotating frames) through pickle.

    Parameters
    ----------
    tmpdir
        pytest fixture providing a per-test temporary directory; the pickle
        file is written there.
    """
    filename = tmpdir / "hamil.pkl"

    p = KeplerPotential(m=1.0, units=solarsystem)

    frames = [
        StaticFrame(units=solarsystem),
        ConstantRotatingFrame(Omega=[0, 0, 1] / u.yr, units=solarsystem),
    ]
    for fr in frames:
        H = Hamiltonian(potential=p, frame=fr)

        with open(filename, "wb") as f:
            pickle.dump(H, f)

        # Only unpickle data this test just wrote itself; never load pickles
        # from untrusted sources.
        with open(filename, "rb") as f:
            H2 = pickle.load(f)

        # The loaded object used to be discarded, so the test only proved that
        # dump/load did not raise. Check that the round trip actually rebuilt
        # an equivalent Hamiltonian.
        assert isinstance(H2, Hamiltonian)
        assert type(H2.potential) is type(H.potential)
        assert type(H2.frame) is type(H.frame)
        assert H2.potential.parameters["m"] == H.potential.parameters["m"]


def test_regression_integrate_orbit_shape():
    """
    Regression test: integrate_orbit must reject a wrongly shaped initial
    condition array with a ValueError.
    """
    hamiltonian = Hamiltonian(
        potential=KeplerPotential(m=1.0),
        frame=StaticFrame(),
    )

    # A (5, 6) array of zeros is the malformed input being checked.
    bad_w0 = np.zeros((5, 6))
    times = np.linspace(0, 1, 128)

    with pytest.raises(ValueError):
        hamiltonian.integrate_orbit(bad_w0, t=times)
