import numpy as np
from ase.units import kB
from hiphive.io.logging import logger

# Module-specific child of the hiphive logger; used below to emit warnings
# about acoustic sum rules and imaginary modes during phonon rattling.
logger = logger.getChild(__name__)


def generate_phonon_rattled_structures(atoms, fc2, n_structures, temperature):
    """
    Returns list of phonon-rattled configurations.

    Configurations are generated by superimposing harmonic phonon
    eigenmodes with random amplitudes and phase factors consistent
    with a certain temperature.

    Let :math:`\\boldsymbol{X}_{ai}` be the phonon modes indexed by atom
    :math:`a` and mode :math:`i`, :math:`\\omega_i` the phonon
    frequencies, and let :math:`0 < Q_i \\leq 1` and :math:`0 \\leq
    U_i < 1` be uniformly random numbers.  Then

    .. math::

        \\boldsymbol{R}_a
        = \\boldsymbol{R}^0_a
        + \\left(\\frac{k_B T}{m_a} \\right)^{1/2}
        \\sum_i \\frac{1}{\\omega_i} \\boldsymbol{X}_{ai}
        \\sqrt{-2 \\ln Q_i} \\cos(2 \\pi U_i)

    The sum runs over all modes except the three whose squared
    frequencies are closest to zero, which are taken to be the acoustic
    modes. Imaginary modes are included using the square root of the
    absolute value of their squared frequency.

    Random numbers are drawn from NumPy's global random state, so
    ``numpy.random.seed`` can be used to make the output reproducible.

    See: West and Estreicher PRL 96, 115504 (2006)

    Parameters
    ----------
    atoms : ase.Atoms
        prototype structure; it is not modified, each generated
        structure is a displaced copy of it
    fc2 : numpy.ndarray
        second order force constant matrix, with shape `(3N, 3N)` or
        `(N, N, 3, 3)`. The latter is converted to the former internally.
    n_structures : int
        number of structures to generate
    temperature : float
        temperature in Kelvin

    Returns
    -------
    list(ase.Atoms)
        generated structures
    """
    structures = []
    # Diagonalize the dynamical matrix once and reuse it for every structure.
    pr = _PhononRattler(atoms.get_masses(), fc2)
    for _ in range(n_structures):
        atoms_tmp = atoms.copy()
        pr(atoms_tmp, temperature)
        structures.append(atoms_tmp)
    return structures


def _phonon_rattle(m_a, T, w2_s, e_sai):
    """ Thermal excitation of phonon modes as described by West and
    Estreicher PRL 96, 115504 (2006).

    Index conventions used in the variable names:

    * ``_s`` is a mode index
    * ``_i`` is a Cartesian index
    * ``_a`` is an atom index

    The three modes whose squared frequencies are closest to zero are
    assumed to be the acoustic modes and are excluded. A warning is
    logged if any of them exceeds 1e-6 in magnitude. Modes with negative
    squared frequency (imaginary modes) are kept, using the square root
    of the absolute value, and a warning is logged.

    Parameters
    ----------
    m_a : numpy.ndarray
        masses (N,); any sequence convertible to a float array is accepted
    T : float
        temperature in Kelvin
    w2_s : numpy.ndarray
        the squared frequencies from the eigenvalue problem (3*N,)
    e_sai : numpy.ndarray
        polarizations (3*N, N, 3)

    Returns
    -------
    displacements : numpy.ndarray
        shape (N, 3)
    """
    # Coerce inputs so plain sequences work (a list of masses would
    # otherwise fail in the float / m_a division below).
    m_a = np.asarray(m_a, dtype=float)
    w2_s = np.asarray(w2_s)
    e_sai = np.asarray(e_sai)

    n_modes = 3 * len(m_a)
    # Column vector so the per-atom prefactor broadcasts over the three
    # Cartesian components of the (N, 3) displacement array.
    prefactor_a = np.sqrt(2 * kB * T / m_a).reshape(-1, 1)

    # The three modes closest to zero are assumed to be zero, i.e. acoustic sum rules are assumed
    order = np.argsort(np.abs(w2_s))
    w2_sorted = w2_s[order]
    w2_gamma = w2_sorted[:3]
    w2_s = w2_sorted[3:]
    e_sai = e_sai[order][3:]
    if np.any(np.abs(w2_gamma) > 1e-6):
        logger.warning('Acoustic sum rules not enforced, squared frequencies: {}'.format(w2_gamma))

    # treat imaginary modes as real
    if np.any(w2_s < 0):
        logger.warning('Imaginary modes present')
    # NOTE(review): any remaining exactly-zero frequency (e.g. additional
    # zero modes beyond the three acoustic ones) makes the division below
    # produce inf -- confirm such inputs cannot occur.
    w_s = np.sqrt(np.abs(w2_s))

    # Random phases in [0, 2pi) and amplitudes sqrt(-ln Q) with Q = 1 - U
    # in (0, 1]; together with the factor 2 in the prefactor this gives
    # sqrt(kB T / m) * sqrt(-2 ln Q) * cos(phase).
    phases_s = np.random.uniform(0, 2 * np.pi, size=n_modes - 3)
    amplitudes_s = np.sqrt(-np.log(1 - np.random.random(n_modes - 3)))

    # Sum the weighted polarization vectors over modes -> shape (N, 3).
    u_ai = prefactor_a * np.tensordot(amplitudes_s * np.cos(phases_s) / w_s, e_sai, (0, 0))
    return u_ai  # displacements


class _PhononRattler:
    """
    Class to be able to conveniently save modes and frequencies needed
    for phonon rattle.

    Parameters
    ----------
    masses : numpy.ndarray
        masses (N,)
    force_constants : numpy.ndarray
        second order force constant matrix, with shape `(3N, 3N)` or
        `(N, N, 3, 3)`. The conversion will be done internally if.
    """
    def __init__(self, masses, force_constants):
        n_atoms = len(masses)
        if len(force_constants.shape) == 4:  # assume shape = (n_atoms, n_atoms, 3, 3)
            force_constants = force_constants.transpose(0, 2, 1, 3)
            force_constants = force_constants.reshape(3 * n_atoms, 3 * n_atoms)
            # Now the fc should have shape = (n_atoms * 3, n_atoms * 3)
        # Construct the dynamical matrix
        inv_root_masses = (1 / np.sqrt(masses)).repeat(3)
        D = np.outer(inv_root_masses, inv_root_masses)
        D *= force_constants
        # find modes and energies
        w2_s, e_sai = np.linalg.eigh(D)
        # reshape to get atom index and Cartesian index separate
        e_sai = e_sai.T.reshape(-1, n_atoms, 3)
        self.w2_s = w2_s
        self.e_sai = e_sai
        self.masses = masses

    def __call__(self, atoms, T):
        """ rattle atoms by adding displacements

        Parameters
        ----------
        atoms : ase.Atoms
            Ideal structure to add displacements to.
        T : float
            temperature in Kelvin
        """
        u_ai = _phonon_rattle(self.masses, T, self.w2_s, self.e_sai)
        atoms.positions += u_ai
