import os.path

import numpy as np
import pytest

from aspire.basis import Coef, FFBBasis3D
from aspire.utils import grid_3d
from aspire.volume import AsymmetricVolume, Volume

from ._basis_util import UniversalBasisMixin, basis_params_3d, show_basis_params

# Directory holding the saved .npy reference arrays, located next to this module.
DATA_DIR = os.path.join(
    os.path.dirname(__file__),
    "saved_test_data",
)

# One FFBBasis3D instance per (size, dtype) pair from the shared 3D basis
# parameters; TestFFBBasis3D is parametrized over these instances.
test_bases = [FFBBasis3D(size, dtype=dt) for size, dt in basis_params_3d]


@pytest.mark.parametrize("basis", test_bases, ids=show_basis_params)
class TestFFBBasis3D(UniversalBasisMixin):
    """
    Regression tests for FFBBasis3D.

    Every test runs once per basis in ``test_bases``. Each one compares against
    hard-coded values or reference arrays saved under ``DATA_DIR``; the saved
    files are named for an 8x8x8 grid. ``UniversalBasisMixin`` presumably
    contributes additional shared basis tests (defined elsewhere; confirm).
    """

    def testFFBBasis3DIndices(self, basis):
        # The expected index table lists (ell, m, k) triples with ell outermost,
        # then m = -ell..ell, then k innermost. Only the number of radial
        # indices k retained for each ell (ell = 0..7) has to be spelled out;
        # this yields 98 entries in total.
        k_counts = (3, 3, 3, 2, 2, 1, 1, 1)
        expected = {"ells": [], "ms": [], "ks": []}
        for ell, k_count in enumerate(k_counts):
            for m in range(-ell, ell + 1):
                for k in range(k_count):
                    expected["ells"].append(float(ell))
                    expected["ms"].append(float(m))
                    expected["ks"].append(float(k))

        indices = basis.indices()
        for key in ("ells", "ms", "ks"):
            assert np.allclose(indices[key], expected[key])

    def testFFBBasis3DNorms(self, basis):
        # Hard-coded expected output of basis.norms() (16 values).
        expected_norms = np.array(
            [
                1.80063263231421,
                0.900316316157109,
                0.600210877438065,
                1.22885897287928,
                0.726196138639673,
                0.516613361675378,
                0.936477951517100,
                0.610605075148750,
                0.454495363516488,
                0.756963071176142,
                0.527618747123993,
                0.635005913075500,
                0.464867421846148,
                0.546574142892508,
                0.479450758110826,
                0.426739123914569,
            ]
        )
        assert np.allclose(basis.norms(), expected_norms)

    def testFFBBasis3DEvaluate(self, basis):
        # 98 expansion coefficients (the same count as the index table above),
        # cast to the basis dtype before evaluation.
        coef_values = np.array(
            [
                1.07338590e-01,
                1.23690941e-01,
                6.44482039e-03,
                -5.40484306e-02,
                -4.85304586e-02,
                1.09852144e-02,
                3.87838396e-02,
                3.43796455e-02,
                -6.43284705e-03,
                -2.86677145e-02,
                -1.42313328e-02,
                -2.25684091e-03,
                -3.31840727e-02,
                -2.59706174e-03,
                -5.91919887e-04,
                -9.97433028e-03,
                9.19123928e-04,
                1.19891589e-03,
                7.49154982e-03,
                6.18865229e-03,
                -8.13265715e-04,
                -1.30715655e-02,
                -1.44160603e-02,
                2.90379956e-03,
                2.37066082e-02,
                4.88805735e-03,
                1.47870707e-03,
                7.63376018e-03,
                -5.60619559e-03,
                1.05165081e-02,
                3.30510143e-03,
                -3.48652120e-03,
                -4.23228797e-04,
                1.40484061e-02,
                1.42914291e-03,
                -1.28129504e-02,
                2.19868825e-03,
                -6.30835037e-03,
                1.18524223e-03,
                -2.97855052e-02,
                1.15491057e-03,
                -8.27947006e-03,
                3.45442781e-03,
                -4.72868856e-03,
                2.66615329e-03,
                -7.87929790e-03,
                8.84126590e-04,
                1.59402808e-03,
                -9.06854048e-05,
                -8.79119004e-03,
                1.76449039e-03,
                -1.36414673e-02,
                1.56793855e-03,
                1.44708445e-02,
                -2.55974802e-03,
                5.38506357e-03,
                -3.24188673e-03,
                4.81582945e-04,
                7.74260101e-05,
                5.48772082e-03,
                1.92058500e-03,
                -4.63538896e-03,
                -2.02735133e-03,
                3.67592386e-03,
                7.23486969e-04,
                1.81838422e-03,
                1.78793284e-03,
                -8.01474060e-03,
                -8.54007528e-03,
                1.96353845e-03,
                -2.16254252e-03,
                -3.64243996e-05,
                -2.27329863e-03,
                1.11424393e-03,
                -1.39389189e-03,
                2.57787159e-04,
                3.66918811e-03,
                1.31477774e-03,
                6.82220128e-04,
                1.41822851e-03,
                -1.89476924e-03,
                -6.43966255e-05,
                -7.87888465e-04,
                -6.99459279e-04,
                1.08918981e-03,
                2.25264584e-03,
                -1.43651015e-04,
                7.68377620e-04,
                5.05955256e-04,
                2.66936132e-06,
                2.24934884e-03,
                6.70529439e-04,
                4.81121742e-04,
                -6.40789745e-05,
                -3.35915672e-04,
                -7.98651783e-04,
                -9.82705453e-04,
                6.46337066e-05,
            ],
            dtype=basis.dtype,
        )

        volume = Coef(basis, coef_values).evaluate()

        # The saved reference is transposed on load (tagged RCOPT); presumably
        # this reorders the stored axes to match the in-memory layout -- confirm.
        expected = np.load(
            os.path.join(DATA_DIR, "ffbbasis3d_xcoef_out_8_8_8.npy")
        ).T  # RCOPT

        assert np.allclose(volume, expected, atol=1e-2)

    def testFFBBasis3DEvaluate_t(self, basis):
        vol_path = os.path.join(DATA_DIR, "ffbbasis3d_xcoef_in_8_8_8.npy")
        vol_data = np.load(vol_path).T.astype(basis.dtype, copy=False)  # RCOPT

        coefs = basis.evaluate_t(Volume(vol_data))

        # Only the first slice along the last axis of the saved array is compared.
        expected = np.load(os.path.join(DATA_DIR, "ffbbasis3d_vcoef_out_8_8_8.npy"))
        assert np.allclose(coefs, expected[..., 0], atol=1e-2)

    def testFFBBasis3DExpand(self, basis):
        vol_path = os.path.join(DATA_DIR, "ffbbasis3d_xcoef_in_8_8_8.npy")
        vol_data = np.load(vol_path).T.astype(basis.dtype, copy=False)  # RCOPT

        # expand() receives the raw ndarray here, not a Volume wrapper.
        coefs = basis.expand(vol_data)

        # Only the first slice along the last axis of the saved array is compared.
        expected = np.load(
            os.path.join(DATA_DIR, "ffbbasis3d_vcoef_out_exp_8_8_8.npy")
        )
        assert np.allclose(coefs, expected[..., 0], atol=1e-2)


# High-resolution (L=256) round-trip case, tagged with the project's
# ``expensive`` marker; presumably deselected in routine runs -- confirm in
# the pytest configuration.
params = [
    pytest.param(256, np.float32, marks=pytest.mark.expensive),
]


@pytest.mark.parametrize("L, dtype", params)
def testHighResFFBbasis3D(L, dtype):
    """
    Round-trip a seeded AsymmetricVolume through FFBBasis3D at resolution L.

    The volume is mapped to coefficients with ``evaluate_t`` and back with
    ``evaluate``. Inside the normalized unit sphere, the reconstruction must
    match the input to within ``atol=1e-3``.
    """
    basis = FFBBasis3D(L, dtype=dtype)
    original = AsymmetricVolume(L=L, C=1, K=64, dtype=dtype, seed=42).generate()

    round_trip = basis.evaluate(basis.evaluate_t(original))

    # Compare only voxels strictly inside the sphere of (normalized) radius 1.
    inside = grid_3d(L, normalized=True)["r"] < 1
    assert np.allclose(
        round_trip.asnumpy()[0][inside], original.asnumpy()[0][inside], atol=1e-3
    )
