# -*- coding: utf-8 -*-
"""
Created on Mon 18 Mar 2024 15:49:00

@author: Chris Nill
@author: Lukas Ahlheit
"""
import warnings
from multiprocessing import Pool

# %%

import scipy.special as scs
import numpy as np
import arc
import pint
from functools import lru_cache, partial

from scipy.integrate import quad_vec, quad
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

# Shared pint unit registry; every physical quantity in this module is built from it.
u = pint.get_application_registry()

# Reduced Planck constant carrying an explicit 1/rad factor, so that quantities such as
# field * dipole / hbar come out in rad/s in pint's unit bookkeeping.
hbar = u.hbar / u.rad
# ARC model of rubidium-87 used for all atomic data; cpp_numerov=True presumably selects
# ARC's compiled Numerov solver for radial wavefunctions (see ARC docs) — TODO confirm.
atom = arc.Rubidium87(cpp_numerov=True)
# Atomic mass as a pint quantity (atom.mass is taken to be in kg here).
M = atom.mass * u.kg


def double_complex_integral(func, x0, x1, rho_0, rho_1, **kwargs):
    """
    Integrate a (possibly complex-valued) function of two variables over a rectangle.

    The integral is evaluated as nested 1D quadratures: for each x the inner integral over rho
    is computed first, and that result is then integrated over x. No Jacobian (e.g. a factor rho
    for cylindrical coordinates) is added, so include it in ``func`` if needed.

    :param func: integrand, called as ``func(x, rho, **kwargs)``
    :param x0: lower bound of x
    :param x1: upper bound of x
    :param rho_0: lower bound of rho
    :param rho_1: upper bound of rho
    :param kwargs: extra keyword arguments forwarded unchanged to ``func``
    :return: integral value; both quadratures use a relative tolerance of 1e-4
    """
    tolerance = 1e-4

    def integrate_over_rho(x):
        # Freeze x and expose a one-variable integrand in rho to quad.
        def integrand(rho):
            return func(x, rho, **kwargs)

        return quad(integrand, rho_0, rho_1, epsrel=tolerance, complex_func=True)[0]

    return quad(integrate_over_rho, x0, x1, epsrel=tolerance, complex_func=True)[0]


@lru_cache(maxsize=65536)
def d_fun(fun, x):
    """
    Numerical first derivative of ``fun`` at ``x`` from a symmetric (central) difference.

    The step is 1e-5 expressed in the unit of ``x`` (``x.u``), so ``x`` must be a pint quantity
    (or any object exposing a ``.u`` attribute). Results are memoised with ``lru_cache``, so
    ``fun`` and ``x`` must be hashable.
    """
    step = 1e-5 * x.u
    forward = fun(x + step)
    backward = fun(x - step)
    return (forward - backward) / (2 * step)


@lru_cache(maxsize=65536)
def d2_fun(fun, x):
    """
    Numerical second derivative of ``fun`` at ``x`` from the three-point central stencil.

    The step is 1e-5 expressed in the unit of ``x`` (``x.u``), so ``x`` must be a pint quantity
    (or any object exposing a ``.u`` attribute). Results are memoised with ``lru_cache``, so
    ``fun`` and ``x`` must be hashable.
    """
    step = 1e-5 * x.u
    upper = fun(x + step)
    centre = fun(x)
    lower = fun(x - step)
    return (upper - 2 * centre + lower) / step ** 2


@lru_cache(maxsize=65536)
def radial_wavefunction(n, l, j):
    """
    Radial wavefunction of the Rb87 state (n, l, j), obtained from ARC's numerical integration.

    The radial grid is produced by ARC (it is an output, not an input). The integration range runs
    from ``atom.alphaC ** (1/3)`` to ``2 n (n + 15)`` (atomic units) with step 0.1, and the spin is
    fixed to s = 1/2.

    :param n: principal quantum number n
    :param l: orbital angular momentum l (0 for S states)
    :param j: total electronic angular momentum j (1/2 for S states)
    :return: tuple ``(r, Rr)``. ``r`` is in atomic units of length; ``Rr`` carries units of
        (atomic unit of length)^(-1/2)
    """
    start = atom.alphaC ** (1 / 3.0)
    stop = 2.0 * n * (n + 15.0)
    # ARC's energy divided by 27.211 (eV per Hartree) to pass it in atomic units.
    energy_hartree = atom.getEnergy(n, l, j) / 27.211
    radii, values = atom.radialWavefunction(l, 0.5, j, energy_hartree, start, stop, 0.1)
    return radii * u.atomic_unit_of_length, values / np.sqrt(1 * u.atomic_unit_of_length)


class ColdGas:
    def __init__(self, rydberg_n,
                 waist_12=5 * u.micrometer, P_12=0.5 * u.mW, X_0_12=0 * u.um, waist_23=40 * u.micrometer,
                 P_23=0.5 * u.mW, X_0_23=0 * u.um, delta_12=1 * u.GHz, excitation_pulse_width=0.25 * u.us,
                 waist_trap_p=24 * u.micrometer, waist_trap_m=24 * u.micrometer, X_0_p=600 * u.micrometer,
                 X_0_m=0 * u.micrometer, P_p=17 * u.mW, P_m=0.75 * 17 * u.mW, cloud_length=20 * u.micrometer,
                 T_atoms=4e-6 * u.kelvin, room_temp=293 * u.K, n_of_x=1000, truncation=9, multiprocessing=False,
                 no_hash=False, rydberg_only_free_electron=True):
        """
        Set up a ColdGas experiment from its physical parameters.

        All parameters are stored as attributes and may be changed afterwards. The ``lru_cache``-decorated
        methods are keyed on this object's hash and equality. With the default ``no_hash=False`` the hash is
        recomputed from the current parameters, so changed parameters lead to fresh cache entries.

        :param rydberg_n: principal quantum number n of the Rydberg state
        :param waist_12: waist of the 1-2 laser
        :param P_12: power of the 1-2 laser
        :param X_0_12: position of the focus of the 1-2 laser
        :param waist_23: waist of the 2-3 laser
        :param P_23: power of the 2-3 laser
        :param X_0_23: position of the focus of the 2-3 laser
        :param delta_12: detuning of the 1-2 laser
        :param excitation_pulse_width: duration of the excitation laser pulse (standard deviation)
        :param waist_trap_p: waist of the incident trapping laser at the center of the cloud
        :param waist_trap_m: waist of the reflected trapping laser at the center of the cloud
        :param X_0_p: position of the focus of the incident beam
        :param X_0_m: position of the focus of the reflected beam
        :param P_p: laser power of the incident beam
        :param P_m: laser power of the reflected beam
        :param cloud_length: length of the cold atom cloud
        :param T_atoms: temperature of the atoms
        :param room_temp: room temperature
        :param n_of_x: number of grid points for the Mathieu eigenfunctions
        :param truncation: maximum number of quasi-bound states for the Mathieu functions (the original
                documentation notes that an error is raised if it is too low)
        :param multiprocessing: flag to enable multiprocessing (also compared in ``__eq__``)
        :param no_hash: if True, the hash is computed once here from the initial parameters and then reused.
                BEWARE: parameters changed later are NOT reflected in the hash, so cached results computed
                for the old parameters may be returned for this object.
        :param rydberg_only_free_electron: if True, skip the dynamic polarizability calculation of the Rydberg
                state and use only the free electron polarizability for all n. This speeds up calculations.
        """
        # Rydberg state and 1-2 excitation beam
        self.rydberg_n = rydberg_n
        self._waist_12 = waist_12
        self.P_12 = P_12
        self.X_0_12 = X_0_12
        self.excitation_pulse_width = excitation_pulse_width
        # 2-3 excitation beam
        self._waist_23 = waist_23
        self.P_23 = P_23
        self.X_0_23 = X_0_23
        self.delta_12 = delta_12
        # Incident (p) and reflected (m) lattice beams
        self._waist_trap_p = waist_trap_p
        self._waist_trap_m = waist_trap_m
        self.P_p = P_p
        self.P_m = P_m
        self.X_0_p = X_0_p
        self.X_0_m = X_0_m
        # Atom cloud and environment
        self.cloud_length = cloud_length
        self.T_atoms = T_atoms
        self.room_temp = room_temp
        # Numerical settings and flags
        self.n_of_x = n_of_x
        self.truncation = truncation
        self.__multiprocessing = multiprocessing
        self.__no_hash = no_hash
        self.__rydberg_only_free_electron = rydberg_only_free_electron

        if not no_hash:
            return
        # Freeze the hash of the initial parameter set (used by __hash__ when no_hash is set).
        frozen_key = (
            self.rydberg_n,
            self._waist_12,
            self.P_12,
            self.X_0_12,
            self.excitation_pulse_width,
            self._waist_23,
            self.P_23,
            self.X_0_23,
            self.delta_12,
            self._waist_trap_p,
            self._waist_trap_m,
            self.P_p,
            self.P_m,
            self.X_0_p,
            self.X_0_m,
            self.cloud_length,
            self.T_atoms,
            self.room_temp,
            self.n_of_x,
            self.truncation,
            self.__rydberg_only_free_electron,
        )
        self.__hash = hash(frozen_key)

    def __eq__(self, other):
        """
        Check if two ColdGas objects are equal, parameter by parameter.

        All physical parameters, n_of_x, truncation, and the multiprocessing and rydberg_only_free_electron
        flags must match. Comparing with an object that is not a ColdGas returns NotImplemented. Python then
        falls back to its default handling (``==`` yields False) instead of raising AttributeError on the
        missing attributes.

        :param other: object to compare with
        :return: boolean, or NotImplemented for non-ColdGas operands
        """
        if not isinstance(other, ColdGas):
            return NotImplemented
        return self.rydberg_n == other.rydberg_n and \
            self._waist_12 == other._waist_12 and \
            self.P_12 == other.P_12 and \
            self.X_0_12 == other.X_0_12 and \
            self.excitation_pulse_width == other.excitation_pulse_width and \
            self._waist_23 == other._waist_23 and \
            self.P_23 == other.P_23 and \
            self.X_0_23 == other.X_0_23 and \
            self.delta_12 == other.delta_12 and \
            self._waist_trap_p == other._waist_trap_p and \
            self._waist_trap_m == other._waist_trap_m and \
            self.P_p == other.P_p and \
            self.P_m == other.P_m and \
            self.X_0_p == other.X_0_p and \
            self.X_0_m == other.X_0_m and \
            self.cloud_length == other.cloud_length and \
            self.T_atoms == other.T_atoms and \
            self.room_temp == other.room_temp and \
            self.n_of_x == other.n_of_x and \
            self.truncation == other.truncation and \
            self.__multiprocessing == other.__multiprocessing and \
            self.__rydberg_only_free_electron == other.__rydberg_only_free_electron

    def __hash__(self):
        """
        Hash of the experiment parameters; lets ColdGas instances serve as lru_cache keys.

        With ``no_hash`` set, the value frozen at construction time is returned. Otherwise the hash
        is recomputed from the current parameters (the multiprocessing flag is not included).
        :return: integer hash
        """
        if self.__no_hash:
            return self.__hash
        identity = (
            self.rydberg_n,
            self._waist_12,
            self.P_12,
            self.X_0_12,
            self.excitation_pulse_width,
            self._waist_23,
            self.P_23,
            self.X_0_23,
            self.delta_12,
            self._waist_trap_p,
            self._waist_trap_m,
            self.P_p,
            self.P_m,
            self.X_0_p,
            self.X_0_m,
            self.cloud_length,
            self.T_atoms,
            self.room_temp,
            self.n_of_x,
            self.truncation,
            self.__rydberg_only_free_electron,
        )
        return hash(identity)

    @staticmethod
    def lambda_12():
        """
        Wavelength of the 5S_1/2 -> 5P_3/2 (1-2) transition of the module's Rb87 atom.

        :return: wavelength as a pint quantity in metres
        """
        wavelength_m = atom.getTransitionWavelength(5, 0, 1 / 2, 5, 1, 3 / 2)
        return wavelength_m * u.m

    #    @lru_cache(maxsize=65536)
    def waist_12(self, X):
        """
        Gaussian-beam radius of the 1-2 laser at axial position X.

        w(X) = w_0 * sqrt(1 + (X - X_0_12)^2 / X_R^2), with Rayleigh length X_R = pi * w_0^2 / lambda_12.

        :param X: position in the trap along beam axis relative to the center of the cloud
        :return: waist in length units
        """
        rayleigh_length = np.pi * self._waist_12 ** 2 / self.lambda_12()
        offset = X - self.X_0_12
        return self._waist_12 * np.sqrt(1 + offset ** 2 / rayleigh_length ** 2)

    #    @lru_cache(maxsize=65536)
    def e_field_12(self, X, rho):
        """
        Electric-field amplitude of the 1-2 laser beam at (X, rho).

        The field at the cloud centre is E_0 = sqrt(16 * c * mu_0 * P_12 / (pi * w(0)^2)). It is scaled by
        the Gaussian-beam factor w(0) / w(X) * exp(-rho^2 / w(X)^2).

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: electric field
        """
        w_centre = self.waist_12(0 * u.um)
        w_here = self.waist_12(X)
        peak_field = np.sqrt(self.P_12 * 16 * u.c * u.mu_0 / (np.pi * w_centre ** 2))
        return peak_field * w_centre / w_here * np.exp(-rho ** 2 / w_here ** 2)

    #    @lru_cache(maxsize=65536)
    def rabi_12(self, X, rho):
        """
        Rabi frequency of the 1-2 (retrieval) pulse: local field times dipole matrix element over hbar.

        The dipole element is ARC's hyperfine matrix element between (5S_1/2, F=2, m_F=2) and
        (5P_3/2, F=3, m_F=3) with q=+1, converted from units of a0*e.

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: Rabi frequency
        """
        dipole = atom.getDipoleMatrixElementHFS(5, 0, 1 / 2, 2, 2, 5, 1, 3 / 2, 3, 3, 1) * u.a0 * u.e
        field = self.e_field_12(X, rho)
        return field * dipole / hbar

    @lru_cache(maxsize=65536)
    def lambda_23(self):
        """
        Wavelength of the 5P_3/2 -> nS_1/2 (2-3) transition for n = rydberg_n.

        :return: wavelength as a pint quantity in metres
        """
        wavelength_m = atom.getTransitionWavelength(5, 1, 3 / 2, self.rydberg_n, 0, 1 / 2)
        return wavelength_m * u.m

    #    @lru_cache(maxsize=65536)
    def waist_23(self, X):
        """
        Gaussian-beam radius of the 2-3 laser at axial position X.

        w(X) = w_0 * sqrt(1 + (X - X_0_23)^2 / X_R^2), with Rayleigh length X_R = pi * w_0^2 / lambda_23.

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :return: waist in length units
        """
        rayleigh_length = np.pi * self._waist_23 ** 2 / self.lambda_23()
        offset = X - self.X_0_23
        return self._waist_23 * np.sqrt(1 + offset ** 2 / rayleigh_length ** 2)

    #    @lru_cache(maxsize=65536)
    def e_field_23(self, X, rho):
        """
        Electric-field amplitude of the 2-3 laser beam at (X, rho).

        The field at the cloud centre is E_0 = sqrt(16 * c * mu_0 * P_23 / (pi * w(0)^2)). It is scaled by
        the Gaussian-beam factor w(0) / w(X) * exp(-rho^2 / w(X)^2).

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: electric field
        """
        w_centre = self.waist_23(0 * u.um)
        w_here = self.waist_23(X)
        peak_field = np.sqrt(self.P_23 * 16 * u.c * u.mu_0 / (np.pi * w_centre ** 2))
        return peak_field * w_centre / w_here * np.exp(-rho ** 2 / w_here ** 2)

    #    @lru_cache(maxsize=65536)
    def rabi_23(self, X, rho):
        """
        Calculate the Rabi frequency for the retrieval pulse 2-3.

        Uses ARC's hyperfine dipole matrix element between (5P_3/2, F=3, m_F=3) and the configured
        Rydberg level (rydberg_n S_1/2, F=2, m_F=2) with q=-1. This is the same Rydberg state that
        lambda_23 uses for the beam wavelength.

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: Rabi frequency (field * dipole / hbar)
        """
        # Use self.rydberg_n instead of a hard-coded n = 70 so the Rabi frequency follows the configured state.
        d = atom.getDipoleMatrixElementHFS(5, 1, 3 / 2, 3, 3, self.rydberg_n, 0, 1 / 2, 2, 2, -1) * u.a0 * u.e
        return self.e_field_23(X, rho) * d / hbar

    def rabi_31(self, X, rho):
        """
        Effective two-photon Rabi frequency of the 3-1 excitation: Omega_12 * Omega_23 / (2 * delta_12).

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: Rabi frequency
        """
        intermediate_detuning = 2 * self.delta_12
        return self.rabi_12(X, rho) * self.rabi_23(X, rho) / intermediate_detuning

    @lru_cache(maxsize=65536)
    def lambda_trap_laser(self):
        """
        Trapping-laser wavelength, taken as the 6P_3/2 -> nS_1/2 transition wavelength for n = rydberg_n.

        :return: wavelength as a pint quantity in metres
        """
        wavelength_m = atom.getTransitionWavelength(6, 1, 3 / 2, self.rydberg_n, 0, 1 / 2)
        return wavelength_m * u.m

    @lru_cache(maxsize=65536)
    def k_trap_laser(self):
        """
        Wave number k = 2 pi / lambda of the trapping laser.

        :return: wave number in 1/length
        """
        wavelength = self.lambda_trap_laser()
        return 2 * u.pi / wavelength

    @lru_cache(maxsize=65536)
    def omega_trap_laser(self):
        """
        Angular frequency omega = 2 pi c / lambda of the trapping laser.

        :return: angular frequency as a pint quantity
        """
        wavelength = self.lambda_trap_laser()
        return u.c * 2 * u.pi / wavelength

    @lru_cache(maxsize=65536)
    def trap_power_ratio(self):
        """
        Ratio P_m / P_p of reflected to incident trapping-laser power.

        :return: dimensionless number
        """
        incident, reflected = self.P_p, self.P_m
        return reflected / incident

    @lru_cache(maxsize=65536)
    def effective_trap_power(self):
        """
        Geometric mean sqrt(P_p * P_m) of the incident and reflected trapping powers.

        :return: effective trap power
        """
        power_product = self.P_p * self.P_m
        return np.sqrt(power_product)

    @lru_cache(maxsize=65536)
    def rabi_trap(self):
        """
        Rabi frequency of the trapping laser on the nS_1/2 <-> 6P_3/2 transition.

        The field used is the mean of the incident and reflected peak amplitudes. The dipole element is ARC's
        hyperfine matrix element between (rydberg_n S_1/2, F=2, m_F=2) and (6P_3/2, F=3, m_F=3) with q=+1.

        :return: Rabi frequency
        """
        dipole = atom.getDipoleMatrixElementHFS(self.rydberg_n, 0, 1 / 2, 2, 2, 6, 1, 3 / 2, 3, 3, 1) * u.a0 * u.e
        mean_field = (self.e_field_trap_p() + self.e_field_trap_m()) / 2
        return mean_field * dipole / hbar

    @lru_cache(maxsize=65536)
    def alpha_f(self):
        """
        Polarizability of the Rydberg nS_1/2 state at the trapping-laser frequency.

        If ``rydberg_only_free_electron`` is set, the free-electron value -e^2 / (m_e * omega_trap^2) is returned
        directly. Otherwise ARC's DynamicPolarizability is evaluated and the scalar and vector parts are summed.
        The basis runs from atom.groundStateN to rydberg_n + 40, the wavelength is rounded to 1e-4 nm, and state
        lifetimes are included. A missing vector part is replaced by zero; a missing scalar part is replaced by
        the free-electron value. Each substitution emits a warning.

        :return: polarizability
        """
        def free_electron_value():
            return -1 * u.e ** 2 / u.m_e / self.omega_trap_laser() ** 2

        if self.__rydberg_only_free_electron:
            return free_electron_value()

        calculator = arc.DynamicPolarizability(atom, self.rydberg_n, 0, 1 / 2)
        calculator.defineBasis(atom.groundStateN, self.rydberg_n + 40)
        wavelength_m = np.round(self.lambda_trap_laser().to(u.nm).magnitude, 4) * 1e-9
        scalar, vector, _, _, _, _ = calculator.getPolarizability(wavelength_m, units="SI",
                                                                  accountForStateLifetime=True)
        # ARC's SI polarizability (Hz m^2 / V^2) times Planck's constant.
        si_unit = u.Hz * u.m ** 2 / u.V ** 2 * u.planck_constant

        if vector is None:
            vector = 0.0
            warnings.warn(rf"(n={self.rydberg_n}) No vector polarizability found. Assuming zero.")
        vector = vector * si_unit

        if scalar is None:
            warnings.warn(
                fr"(n={self.rydberg_n}) No scalar polarizability found. Assuming free electron polarizability.")
            scalar = free_electron_value()
        else:
            scalar = scalar * si_unit
        return scalar + vector

    @lru_cache(maxsize=65536)
    def alpha_g(self):
        """
        Calculate the ground state polarizability of Rb 5S_1/2 including the vector polarizability
        (1e-2 order correction due to circular polarization).

        Evaluated with ARC's DynamicPolarizability (basis from atom.groundStateN to 50, state lifetimes
        included) at the trapping-laser wavelength. As in ``alpha_f``, a missing vector polarizability
        is treated as zero with a warning instead of crashing on ``None + float``.

        :return: polarizability (ARC SI value in Hz m^2 / V^2 times Planck's constant)
        :raises ValueError: if ARC returns no scalar polarizability for the ground state
        """
        dp = arc.DynamicPolarizability(atom, 5, 0, 1 / 2)
        dp.defineBasis(atom.groundStateN, 50)
        alpha0, alpha1, _, _, _, _ = dp.getPolarizability(self.lambda_trap_laser().to(u.m).magnitude, units="SI",
                                                          accountForStateLifetime=True)
        if alpha0 is None:
            # There is no sensible fallback for the ground-state scalar part; fail with a clear message.
            raise ValueError("(5S_1/2) No scalar polarizability found for the ground state.")
        if alpha1 is None:
            alpha1 = 0.0
            warnings.warn("(5S_1/2) No vector polarizability found. Assuming zero.")
        return (alpha0 + alpha1) * u.Hz * u.m ** 2 / u.V ** 2 * u.planck_constant

    def waist_trap_p(self, X):
        """
        Gaussian-beam radius of the incident lattice beam at axial position X.

        w(X) = w_0 * sqrt(1 + (X - X_0_p)^2 / X_R^2), with Rayleigh length X_R = pi * w_0^2 / lambda_trap.

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :return: waist in length units
        """
        rayleigh_length = np.pi * self._waist_trap_p ** 2 / self.lambda_trap_laser()
        offset = X - self.X_0_p
        return self._waist_trap_p * np.sqrt(1 + offset ** 2 / rayleigh_length ** 2)

    def waist_trap_m(self, X):
        """
        Gaussian-beam radius of the reflected lattice beam at axial position X.

        w(X) = w_0 * sqrt(1 + (X - X_0_m)^2 / X_R^2), with Rayleigh length X_R = pi * w_0^2 / lambda_trap.

        :param X: (ndarray) position in the trap along beam axis relative to the center of the cloud
        :return: waist in length units
        """
        rayleigh_length = np.pi * self._waist_trap_m ** 2 / self.lambda_trap_laser()
        offset = X - self.X_0_m
        return self._waist_trap_m * np.sqrt(1 + offset ** 2 / rayleigh_length ** 2)

    @lru_cache(maxsize=65536)
    def e_field_trap_p(self):
        """
        Calculate the electric field amplitude of the incident trapping laser beam at the cloud centre (X = 0).

        E_0 = sqrt(16 * c * mu_0 * P_p / (pi * w_p(0)^2)), the same convention as e_field_12 / e_field_23.
        :return: electric field
        """
        # Pass a length (0 um), as A_trap_plus does, rather than a bare 0 that relies on pint's
        # special-casing of zero when subtracting X_0_p.
        return np.sqrt(self.P_p * 16 * u.c * u.mu_0 / (np.pi * self.waist_trap_p(0 * u.um) ** 2))

    @lru_cache(maxsize=65536)
    def e_field_trap_m(self):
        """
        Calculate the electric field amplitude of the reflected trapping laser beam at the cloud centre (X = 0).

        E_0 = sqrt(16 * c * mu_0 * P_m / (pi * w_m(0)^2)), the same convention as e_field_12 / e_field_23.
        :return: electric field
        """
        # Pass a length (0 um), as A_trap_minus does, rather than a bare 0 that relies on pint's
        # special-casing of zero when subtracting X_0_m.
        return np.sqrt(self.P_m * 16 * u.c * u.mu_0 / (np.pi * self.waist_trap_m(0 * u.um) ** 2))

    #    @lru_cache(maxsize=65536)
    def A_trap_plus(self, rho, X):
        """
        Field amplitude A_plus of the incident trapping beam at (rho, X).

        Despite the "A" naming, the value is an electric-field amplitude. The result is checked to be
        compatible with V/m: A_plus = E_p * w_p(0) / w_p(X) * exp(-rho^2 / w_p(X)^2).

        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :param X: (ndarray) position X along the beam direction in the trap relative to the center of the atomic cloud
        :return: A_plus(rho,X)
        :raises pint.errors.DimensionalityError: if the result is not an electric field
        """
        local_waist = self.waist_trap_p(X)
        amplitude = self.e_field_trap_p() * self.waist_trap_p(0 * u.um) / local_waist * np.exp(
            -rho ** 2 / local_waist ** 2)
        if amplitude.is_compatible_with(u.volt / u.meter):
            return amplitude
        raise pint.errors.DimensionalityError(amplitude, u.volt / u.meter)

    #    @lru_cache(maxsize=65536)
    def A_trap_minus(self, rho, X):
        """
        Field amplitude A_minus of the reflected trapping beam at (rho, X).

        A_minus = E_m * w_m(0) / w_m(X) * exp(-rho^2 / w_m(X)^2), with E_m the peak reflected-beam field.

        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :param X: (ndarray) position X along the beam direction in the trap relative to the center of the atomic cloud
        :return: A_minus(rho,X)
        """
        local_waist = self.waist_trap_m(X)
        centre_waist = self.waist_trap_m(0 * u.um)
        return self.e_field_trap_m() * centre_waist / local_waist * np.exp(-rho ** 2 / local_waist ** 2)

    def E_trap_squared(self, X, rho):
        """
        Squared magnitude of the total trapping field.

        |E|^2 = 1/8 * (4 A_+ A_- cos^2(k X) + (A_+ + A_-)^2).

        :param X: (ndarray) position X along the beam direction in the trap relative to the center of the atomic cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        """
        a_plus = self.A_trap_plus(rho, X)
        a_minus = self.A_trap_minus(rho, X)
        standing_wave = np.cos(self.k_trap_laser() * X) ** 2
        return 1 / 8 * (4 * a_plus * a_minus * standing_wave + (a_plus + a_minus) ** 2)

    @lru_cache(maxsize=65536)
    def trap_depth(self, X):
        """
        Lattice trap depth for a ground-state atom at axial position X (not only at the cloud centre).

        u0 = 1/4 * alpha_g * 16 * mu_0 * c * sqrt(P_p P_m) / (pi * w_m(X) * w_p(X)).

        :param X: position along the beam axis relative to the center of the cloud
        :return: trap depth in units of energy
        :raises pint.errors.DimensionalityError: if the result is not an energy
        """
        waist_product = self.waist_trap_m(X) * self.waist_trap_p(X)
        u0 = 1 / 4 * self.alpha_g() * 16 * u.mu_0 * u.c * self.effective_trap_power() / u.pi / waist_product
        if u0.is_compatible_with(u.J):
            return u0
        raise pint.errors.DimensionalityError(u0, u.J)

    @lru_cache(maxsize=65536)
    def reduced_dipole_matrix_element_n(self):
        """
        Reduced dipole matrix element D_n = <nS|D|6P_3/2> for n = rydberg_n, from ARC (in units of e * a0).

        :return: Reduced dipole matrix element
        """
        reduced_element = atom.getReducedMatrixElementJ(6, 1, 3 / 2, self.rydberg_n, 0, 1 / 2)
        return reduced_element * u.e * u.a0

    @lru_cache(maxsize=65536)
    def intensity(self, X, rho):
        """
        Relative intensity profile of the trapping beams at (X, rho) (eq. 18).

        Each beam contributes exp(-2 rho^2 / w^2) / w^2, weighted by the power ratio r = P_m / P_p:
        1/sqrt(r) for the incident beam and sqrt(r) for the reflected beam. The sum is multiplied by
        w_p(X) * w_m(X) / 2, so the result is dimensionless (it is not in units of power per area).

        :param X: axial position in the trap
        :param rho: transversal distance rho relative to the beam axis of the trap
        :return: dimensionless relative intensity
        """
        w_p = self.waist_trap_p(X)
        w_m = self.waist_trap_m(X)
        ratio = self.trap_power_ratio()
        incident = np.sqrt(1 / ratio) * np.exp(-2 * rho ** 2 / w_p ** 2) / w_p ** 2
        reflected = np.sqrt(ratio) * np.exp(-2 * rho ** 2 / w_m ** 2) / w_m ** 2
        return w_p * w_m / 2 * (incident + reflected)

    @lru_cache(maxsize=65536)
    def theta_n(self):
        """
        Landscape factor theta_n = <nS|cos(2kx)|nS> of the Rydberg state in the lattice.

        For an S state the angular average of cos(2kx) is the spherical Bessel function j_0(2kr).
        theta_n is therefore the radial integral of Rr^2 * j_0(2kr) dr, evaluated as a left Riemann
        sum on the grid returned by radial_wavefunction(n=rydberg_n, l=0, j=1/2).
        From Lampen paper PHYSICAL REVIEW A 98, 033411 (2018), page 7, eq. on left side above
        "2. Reduced dipole moment matrix elements".

        :return: dimensionless value of theta_n
        :raises pint.errors.DimensionalityError: if 2 k r is not dimensionless
        """
        r, Rr = radial_wavefunction(n=self.rydberg_n, l=0, j=1 / 2)
        if len(r) < 2:
            return 0

        # Vectorised over the grid (the original per-point loop paid pint overhead on every element).
        dr = r[1:] - r[:-1]
        bessel_argument = (2 * self.k_trap_laser() * r[:-1]).to_base_units()
        if not bessel_argument.is_compatible_with(u.pi):
            raise pint.errors.DimensionalityError(bessel_argument, u.pi)
        j0 = scs.spherical_jn(0, bessel_argument.magnitude, derivative=False)
        return np.sum(Rr[:-1] ** 2 * j0 * dr)

    #    @lru_cache(maxsize=65536)
    def ground_state_lattice_potential(self, X, rho):
        """
        Lattice (X-dependent) light shift of a ground-state atom, Eq. 9a of the Kuzmich2018 paper.

        U = -1/4 * alpha_g * A_+ * A_- * cos^2(k X).

        :param X: (ndarray) position X along the beam direction in the trap relative to the center of the atomic cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: potential (light-shift) of a ground state atom experienced by the trap
        :raises pint.errors.DimensionalityError: if the result is not an energy
        """
        a_plus = self.A_trap_plus(rho=rho, X=X)
        a_minus = self.A_trap_minus(rho=rho, X=X)
        standing_wave = np.cos(self.k_trap_laser() * X) ** 2
        U = -1 / 4 * self.alpha_g() * a_plus * a_minus * standing_wave
        if U.is_compatible_with(u.J):
            return U
        raise pint.errors.DimensionalityError(U, u.J)

    # @lru_cache(maxsize=65536)
    def rydberg_state_lattice_potential(self, trap_detuning, X, rho):
        """
        Lattice (X-dependent) light shift of a Rydberg-state atom, Eq. 9b of the Kuzmich2018 paper.
        Circular polarization is assumed.

        U = -1/4 * A_+ * A_- * cos^2(k X) * (D_n^2 / (4 hbar Delta) - |alpha_f| * theta_n).

        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :param X: (ndarray) position X along the beam direction in the trap relative to the center of the atomic cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: potential (light-shift) of a Rydberg state atom experienced by the trap
        """
        a_plus = self.A_trap_plus(rho=rho, X=X)
        a_minus = self.A_trap_minus(rho=rho, X=X)
        standing_wave = np.cos(self.k_trap_laser() * X) ** 2
        effective_polarizability = (self.reduced_dipole_matrix_element_n() ** 2 / (4 * hbar * trap_detuning)
                                    - np.abs(self.alpha_f()) * self.theta_n())
        return -1 / 4 * a_plus * a_minus * standing_wave * effective_polarizability

    # @lru_cache(maxsize=65536)
    def ground_state_non_lattice_potential(self, X, rho):
        """
        Non-lattice light shift of a ground-state atom (the part that does not oscillate along the beam),
        Eq. 10a of the Kuzmich2018 paper. Circular polarization is assumed.

        U = -alpha_g / 16 * (A_+ - A_-)^2.

        :param X: (ndarray) position X along the beam direction in the trap relative to the center of the atomic cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: potential (light-shift) of a ground state atom experienced by the trap
        """
        amplitude_difference = self.A_trap_plus(rho=rho, X=X) - self.A_trap_minus(rho=rho, X=X)
        return -self.alpha_g() / 16 * amplitude_difference ** 2

    @lru_cache(maxsize=65536)
    def ground_state_running_wave_potential(self, X, rho):
        """
        Light shift of a ground-state atom in a running wave (incident beam only), after Eq. 10a of the
        Kuzmich2018 paper. Circular polarization is assumed.

        U = -alpha_g / 16 * A_+^2 * 4.

        :param X: (ndarray) position X along the beam direction in the trap relative to the center of the atomic cloud
        :param rho: (ndarray) transversal distance rho relative to the beam axis of the trap
        :return: potential (light-shift) of a ground state atom experienced by the trap
        """
        a_plus = self.A_trap_plus(rho=rho, X=X)
        return -self.alpha_g() / 16 * a_plus ** 2 * 4

    # @lru_cache(maxsize=65536)
    def rydberg_state_non_lattice_potential(self, trap_detuning, X, rho):
        """
        Non-lattice light shift of a Rydberg-state atom (the part that does not oscillate along the beam).
        For the definition see the Kuzmich2018 paper (cited here as Eq. 9b).

        U = -D_n^2 / (64 hbar Delta) * (A_+ - A_-)^2
            + |alpha_f| / 16 * (2 A_+ A_- (1 - theta_n) + (A_+ - A_-)^2).

        NOTE(review): the lattice counterpart also cites Eq. 9b, while the ground-state non-lattice
        potential cites Eq. 10a -- confirm the equation number.

        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :param X: position X along the beam direction in the trap relative to the center of the atomic cloud
        :param rho: transversal distance rho relative to the beam axis of the trap
        :return: potential (light-shift) of a Rydberg state atom experienced by the trap
        """
        a_minus = self.A_trap_minus(rho=rho, X=X)
        a_plus = self.A_trap_plus(rho=rho, X=X)
        difference_squared = (a_plus - a_minus) ** 2

        dipole_part = -(self.reduced_dipole_matrix_element_n() ** 2 / (64 * hbar * trap_detuning) * difference_squared)
        free_electron_part = np.abs(self.alpha_f()) / 16 * (
                2 * a_plus * a_minus * (1 - self.theta_n()) + difference_squared)
        return dipole_part + free_electron_part

    @lru_cache(maxsize=65536)
    def rydberg_state_running_wave_potential(self, trap_detuning, X, rho):
        """
        Light shift of a Rydberg-state atom in a running wave (incident beam only).
        For the definition see the Kuzmich2018 paper (cited here as Eq. 9b).

        U = -D_n^2 / (64 hbar Delta) * A_+^2 * 4 + |alpha_f| / 16 * A_+^2 * 4.

        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :param X: position X along the beam direction in the trap relative to the center of the atomic cloud
        :param rho: transversal distance rho relative to the beam axis of the trap
        :return: potential (light-shift) of a Rydberg state atom experienced by the trap
        """
        a_plus = self.A_trap_plus(rho=rho, X=X)
        dipole_part = -(self.reduced_dipole_matrix_element_n() ** 2 / (64 * hbar * trap_detuning)
                        * a_plus ** 2) * 4
        free_electron_part = np.abs(self.alpha_f()) / 16 * (a_plus ** 2 * 4)
        return dipole_part + free_electron_part

    @lru_cache(maxsize=65536)
    def calc_magic_detuning(self):
        """
        Magic detuning of the trapping laser from 6P_3/2 for the standing-wave lattice.

        Delta = D_n^2 / (4 hbar (alpha_g + |alpha_f| theta_n)). This is the detuning at which the Rydberg lattice
        potential's prefactor D_n^2 / (4 hbar Delta) - |alpha_f| theta_n equals alpha_g. Circular polarization
        is assumed.

        :return: magic detuning Delta in units of pi*GHz
        """
        dipole_squared = self.reduced_dipole_matrix_element_n() ** 2
        denominator = 4 * hbar * (self.alpha_g() + np.abs(self.alpha_f()) * self.theta_n())
        return (dipole_squared / denominator).to('pi*GHz')

    @lru_cache(maxsize=65536)
    def calc_magic_detuning_running_wave(self):
        """
        Magic detuning of the trapping laser from 6P_3/2 for a running wave.

        Delta = D_n^2 / (4 hbar (alpha_g + |alpha_f|)). This is the detuning at which the ground- and Rydberg-state
        running-wave potentials coincide. Circular polarization is assumed.

        :return: magic detuning Delta in units of pi*GHz
        """
        dipole_squared = self.reduced_dipole_matrix_element_n() ** 2
        denominator = 4 * hbar * (self.alpha_g() + np.abs(self.alpha_f()))
        return (dipole_squared / denominator).to('pi*GHz')

    @lru_cache(maxsize=65536)
    def _get_mathieu_q(self, X):
        """
        Mathieu parameter q = -(1/4) * trap_depth(X) / E_R, with recoil energy E_R = (hbar k)^2 / (2 M).

        :param X: axial position passed to trap_depth
        :return: q as a plain float
        :raises pint.errors.DimensionalityError: if q does not reduce to a dimensionless number
        """
        recoil_energy = (hbar * self.k_trap_laser()) ** 2 / (2 * M)
        q = (-1 / 4 * self.trap_depth(X) / recoil_energy).to_base_units()
        if q.dimensionless:
            return q.magnitude
        raise pint.errors.DimensionalityError(q, 1)

    @property
    def __x_list_um(self):
        """
        Position grid for the Mathieu eigenfunctions: n_of_x evenly spaced points from -pi/(2k) to +pi/(2k).

        :return: positions in micrometres as a plain ndarray
        """
        half_width = u.pi * u.rad / self.k_trap_laser() / 2
        grid = np.linspace(-half_width, half_width, self.n_of_x)
        return grid.to('um').magnitude

    @lru_cache(maxsize=65536)
    def ground_state_quantized_motion(self, X):
        """
        Calculates the quasi bound state energies of the QMO for a ground state atom with given parameters.

        The motion in one lattice well is solved as a Mathieu problem with q = self._get_mathieu_q(X).
        There are ``truncation`` even solutions ce_0, ce_2, ... (characteristic values a) and
        ``truncation - 1`` solutions se_2, se_4, ... (characteristic values b), so
        2 * truncation - 1 states in total.

        :param X: axial position in the trap (forwarded to _get_mathieu_q and trap_depth)
        :return: dict with
            'num_of_solutions': number of computed states (2 * truncation - 1),
            'num_of_bound_solutions': number of states with negative energy,
            'eigenenergies': pint array of energies (a * (hbar k)^2 / M - trap_depth(X)) / 2,
            'eigenfunctions': array of shape (2 * truncation - 1, n_of_x), each row scaled so that
                              the mean of psi**2 over the grid is 1,
            'x': the position grid in micrometres (pint quantity)
        """
        q = self._get_mathieu_q(X)

        # Characteristic values a_m for m = 0, 2, ..., 2*truncation - 2 (truncation values).
        a_even = scs.mathieu_a(np.arange(0, 2 * self.truncation, 2), q)
        # Characteristic values b_m for m = 2, 4, ..., 2*truncation - 2 (truncation - 1 values).
        a_odd = scs.mathieu_b(np.arange(2, 2 * self.truncation, 2), q)

        # NOTE(review): the string below says the mathieu_sem amplitude is set to 0, but the second loop
        # further down does compute and store mathieu_sem solutions -- confirm which is intended.
        """
        We know from the Hamiltonian that 2k_L X has to be 2pi periodic. This is only justified if we choose for mathieu_cem only even m.
        We assume that the trapping should be symmetric w.r. to the center of the trap. Because of this we set the amplitude of mathieu_sem to 0.
        """
        # Lattice wave number in 1/um, so that k_um * x_um is the dimensionless Mathieu argument (in radians).
        k_um = self.k_trap_laser().to('1/um').magnitude
        # Energies E = a * (hbar k)^2 / (2M) - trap_depth / 2 for each characteristic value.
        omega_gq_even = (a_even * (hbar * self.k_trap_laser()) ** 2 / M - self.trap_depth(X)) / 2
        omega_gq_odd = (a_odd * (hbar * self.k_trap_laser()) ** 2 / M - self.trap_depth(X)) / 2

        # One row per state, one column per grid point.
        psis = np.zeros([2 * self.truncation - 1, self.n_of_x])
        # Zeros times (k_B * uK) gives an energy-dimensioned pint array to store the energies in.
        omegas = np.zeros([2 * self.truncation - 1]) * (1 * u.k_B * u.uK)
        i = 0
        # scipy's Mathieu functions take the angle in degrees (hence * 360 / 2 / pi);
        # [0] selects the function value, discarding the derivative.
        for m, _ in enumerate(omega_gq_even):
            psis[i] = scs.mathieu_cem(2 * m, q, k_um * self.__x_list_um * 360 / 2 / np.pi)[0]
            omegas[i] = omega_gq_even[m]
            i += 1
        for m, _ in enumerate(omega_gq_odd):
            psis[i] = scs.mathieu_sem(2 * m + 2, q, k_um * self.__x_list_um * 360 / 2 / np.pi)[0]
            omegas[i] = omega_gq_odd[m]
            i += 1

        # Scale each eigenfunction so that the mean of psi**2 over the n_of_x grid points is 1.
        for i in range(psis.shape[0]):
            psis[i] /= np.sqrt(np.sum(psis[i] ** 2) / self.n_of_x)

        return {
            'num_of_solutions': len(omegas),
            'num_of_bound_solutions': len(np.where(omegas < 0)[0]),
            'eigenenergies': omegas,
            'eigenfunctions': psis,
            'x': self.__x_list_um * u.um,
        }

    @lru_cache(maxsize=65536)
    def rydberg_state_quantized_motion(self, trap_detuning, X):
        """
        Calculate the quantized motional states of a Rydberg atom in the lattice potential at position X.

        The construction is the same as for the ground state (Mathieu modes ce_{2m} and se_{2m}). The lattice depth,
        and hence q, is rescaled by the factor
            (Dn^2 / (4 hbar trap_detuning) - |alpha_f| theta_n) / alpha_g.

        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :param X: position X along the beam direction in the trap relative to the center of the atomic cloud
        :return: dict with
            'num_of_solutions': number of modes (2 * truncation - 1),
            'num_of_bound_solutions': number of modes with negative energy,
            'eigenenergies': mode energies in k_B * uK (all ce modes first, then all se modes),
            'eigenfunctions': array (modes x n_of_x), each row scaled to a mean square of 1 on the grid,
            'x': grid positions in um
        :raises RuntimeError: if even the highest included mode has a negative energy, i.e. the truncation is too
            small to contain all bound solutions
        """
        Dn = self.reduced_dipole_matrix_element_n()
        depth_factor = Dn ** 2 / (4 * hbar * trap_detuning) - np.abs(self.alpha_f()) * self.theta_n()

        q = (self._get_mathieu_q(X) / self.alpha_g() * depth_factor).to_base_units().magnitude
        n_modes = 2 * self.truncation - 1

        a_even = scs.mathieu_a(np.arange(0, 2 * self.truncation, 2), q)
        a_odd = scs.mathieu_b(np.arange(2, 2 * self.truncation, 2), q)

        hk_squared = (hbar * self.k_trap_laser()) ** 2
        rydberg_depth = self.trap_depth(X) / self.alpha_g() * depth_factor
        energies_even = ((a_even * hk_squared / M - rydberg_depth) / 2).to('k_B microkelvin')
        energies_odd = ((a_odd * hk_squared / M - rydberg_depth) / 2).to('k_B microkelvin')

        x_um = self.__x_list_um
        k_um = self.k_trap_laser().to('1/um').magnitude
        # scipy's Mathieu functions take their argument z = k * x in degrees
        z_deg = k_um * x_um * 360 / 2 / np.pi

        # (Mathieu function, order, energy) per mode: the ce modes first, then the se modes
        modes = [(scs.mathieu_cem, 2 * m, energy) for m, energy in enumerate(energies_even)]
        modes += [(scs.mathieu_sem, 2 * m + 2, energy) for m, energy in enumerate(energies_odd)]

        psis = np.zeros([n_modes, self.n_of_x])
        omegas = np.zeros([n_modes]) * (1 * u.J).to('k_B microkelvin')
        for row, (mathieu_fun, order, energy) in enumerate(modes):
            psis[row] = mathieu_fun(order, q, z_deg)[0]
            omegas[row] = energy

        # every included mode still bound -> higher (possibly bound) modes were cut off by the truncation
        if np.amax(omegas).magnitude < 0:
            raise RuntimeError(
                'Not all bound solutions are included! Increase the truncation of the Mathieu functions.')

        # scale every mode to a mean square of one on the sampling grid
        for psi in psis:
            psi /= np.sqrt(np.sum(psi ** 2) / self.n_of_x)

        return {
            'num_of_solutions': len(omegas),
            'num_of_bound_solutions': int(np.count_nonzero(omegas < 0)),
            'eigenenergies': omegas,
            'eigenfunctions': psis,
            'x': x_um * u.um,
        }

    @lru_cache
    def __k_recoil_ium(self):
        """
        Recoil wave number 2 pi / lambda_12 - 2 pi / lambda_23, returned as a plain float in 1/um.
        """
        two_pi = 2 * u.pi * u.rad
        k_recoil = two_pi / self.lambda_12() - two_pi / self.lambda_23()
        return k_recoil.to('1/um').magnitude

    @lru_cache(maxsize=65536)
    def get_matrix_element_q(self, trap_detuning, m_ground, m_rydberg, X):
        """
        Motional matrix element between one ground state mode and one Rydberg state mode at position X.

        Evaluates the grid average of conj(psi_r) * exp(i k_R x) * psi_g over __x_list_um, where k_R comes from
        __k_recoil_ium.
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :param m_ground: row index of the ground state mode (see ground_state_quantized_motion)
        :param m_rydberg: row index of the Rydberg state mode (see rydberg_state_quantized_motion)
        :param X: position X along the beam direction in the trap relative to the center of the atomic cloud
        :return: complex matrix element
        """
        ground_psi = self.ground_state_quantized_motion(X)['eigenfunctions'][m_ground]
        rydberg_psi = self.rydberg_state_quantized_motion(trap_detuning=trap_detuning, X=X)['eigenfunctions'][m_rydberg]

        x_um = self.__x_list_um
        recoil_phase = np.exp(1j * self.__k_recoil_ium() * x_um)
        integrand = np.conjugate(rydberg_psi) * recoil_phase * ground_psi
        return np.sum(integrand) / len(x_um)

    @lru_cache(maxsize=65536)
    def thermal_distribution_rho_ground(self, X):
        """
        Thermal (Boltzmann) populations of the ground state motional modes at position X.

        p_m = exp(-E_m / (k_B T_atoms)) / sum_m' exp(-E_m' / (k_B T_atoms)), with E_m taken from
        ground_state_quantized_motion.
        :param X: position X along the beam direction in the trap relative to the center of the atomic cloud
        :return: ndarray of populations (floats summing to one), ordered like the modes of
            ground_state_quantized_motion
        :raises ValueError: (only when assertions are enabled) if the populations do not sum to one
        """
        energies = self.ground_state_quantized_motion(X=X)['eigenenergies']
        # Shift by the lowest energy before exponentiating. Bound energies are negative, so exp(-E / kT) overflows
        # for deep traps / cold atoms, which made the result inf / inf = nan. The shift cancels in the ratio.
        exponents = (-(energies - np.min(energies)) / (u.k_B * self.T_atoms)).to('dimensionless').magnitude
        weights = np.exp(exponents)
        rho = weights / np.sum(weights)
        if __debug__ and np.abs(np.sum(rho) - 1) > 1e-7:
            raise ValueError('thermal state not normalized')
        return rho

    @lru_cache(maxsize=65536)
    def calc_lattice_signal(self, time, trap_detuning):
        """
        Lattice dephasing signal at the trap center (X = 0) after the given time.

        Forms S = sum_{m_g, m_r} p(m_g) |M(m_g, m_r)|^2 exp(i (E_g(m_g) - E_r(m_r)) t / hbar), where
        - p are the thermal ground state populations;
        - M are the motional matrix elements from get_matrix_element_q.
        It returns |S|^2.
        :param time: evolution time
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: |S|^2 as a float
        """
        X = 0
        ground = self.ground_state_quantized_motion(X)
        rydberg = self.rydberg_state_quantized_motion(trap_detuning=trap_detuning, X=X)

        phase_g = (ground['eigenenergies'] * time / hbar).to('dimensionless').magnitude
        phase_r = (rydberg['eigenenergies'] * time / hbar).to('dimensionless').magnitude

        # rows: ground modes, columns: Rydberg modes
        populations = self.thermal_distribution_rho_ground(X)[:, np.newaxis]
        phase_factor = np.exp(1j * (phase_g[:, np.newaxis] - phase_r[np.newaxis, :]))
        overlaps = np.array(
            [[self.get_matrix_element_q(trap_detuning=trap_detuning, m_ground=m_g, m_rydberg=m_r, X=X)
              for m_r in np.arange(0, rydberg['num_of_solutions'], 1)]
             for m_g in np.arange(0, ground['num_of_solutions'], 1)],
            dtype=np.complex128)

        signal = np.sum(populations * np.abs(overlaps) ** 2 * phase_factor)
        return np.abs(signal) ** 2

    def calc_lattice_signal_t(self, tlist, trap_detuning):
        """
        Wrapper for calc_lattice_signal to calculate the decay for an array of times.

        :param tlist: sequence of times
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: ndarray with one lattice signal per entry of tlist
        """
        if self.__multiprocessing:
            # process_map creates and shuts down its own process pool. The former extra multiprocessing.Pool was never
            # used and only spawned idle worker processes.
            func = partial(self.calc_lattice_signal, trap_detuning=trap_detuning)
            return np.array(process_map(func, tlist))

        latt_sigs = np.empty(shape=len(tlist))
        for i in tqdm(range(len(tlist))):
            latt_sigs[i] = self.calc_lattice_signal(tlist[i], trap_detuning=trap_detuning)
        return latt_sigs

    @lru_cache(maxsize=65536)
    def calc_non_lattice_signal(self, time):
        """
        Dephasing signal caused by the non-lattice part of the trapping potential after the given time.

        The signal is exp(-2 t^2 / tau_X^2) / (1 + t^2 / tau_rho^2).
        - tau_X is built from the axial intensity gradient at the trap center and the cloud length.
        - tau_rho is built from the radial intensity curvature at the trap center, the beam waists and the atom
          temperature.
        :param time: evolution time
        :return: signal value (dimensionless)
        """
        X = 0

        # first derivative of the intensity along the axis (rho = 0) at X = 0
        gradient_axial = d_fun(lambda x: self.intensity(x, rho=0 * u.m), 0 * u.m)
        # half of the second derivative of the intensity along rho at X = 0
        curvature_radial = d2_fun(lambda r: self.intensity(X=0 * u.m, rho=r), 0 * u.m) / 2

        alpha_f = self.alpha_f()
        alpha_g = self.alpha_g()
        theta_n = self.theta_n()
        depth = self.trap_depth(X)

        axial_rate = np.abs(alpha_f) * depth / (2 * alpha_g) * (1 - theta_n) * gradient_axial * self.cloud_length
        tau_X = hbar * 2 / axial_rate

        radial_rate = np.abs(alpha_f * depth) / (2 * alpha_g) * (1 - theta_n) * curvature_radial
        radial_bracket = (1 / self._waist_12 ** 2 + 2 / self._waist_23 ** 2
                          - curvature_radial * depth / (2 * u.k_B * self.T_atoms))
        tau_rho = hbar * (1 / radial_rate * radial_bracket)

        return np.exp(-2 * time ** 2 / tau_X ** 2) / (1 + time ** 2 / tau_rho ** 2)

    def calc_non_lattice_signal_t(self, tlist):
        """
        Evaluate calc_non_lattice_signal for every time in tlist (serially, with a progress bar).

        :param tlist: sequence of times
        :return: ndarray of floats with one signal value per entry of tlist
        """
        count = len(tlist)
        signals = np.empty(shape=count)
        for idx in tqdm(list(range(count))):
            signals[idx] = self.calc_non_lattice_signal(tlist[idx])
        return signals

    @lru_cache(maxsize=65536)
    def tau_6p_0(self):
        """
        Lifetime of the 6P_3/2 state at the configured room temperature (self.room_temp).

        Computed with ARC's getStateLifetime, including levels up to n = 50.
        :return: lifetime as a pint quantity in seconds
        """
        temperature_k = self.room_temp.to(u.K).magnitude
        lifetime_s = atom.getStateLifetime(6, 1, 3 / 2, temperature=temperature_k, includeLevelsUpTo=50)
        return lifetime_s * u.s

    @lru_cache(maxsize=65536)
    def tau_rydberg_ns(self):
        """
        Lifetime of the nS_1/2 Rydberg state (n = self.rydberg_n) at the configured room temperature.

        Computed with ARC's getStateLifetime, including levels up to n + 30.
        :return: lifetime as a pint quantity in seconds
        """
        n = self.rydberg_n
        temperature_k = self.room_temp.to(u.K).magnitude
        lifetime_s = atom.getStateLifetime(n, 0, 1 / 2, temperature=temperature_k, includeLevelsUpTo=n + 30)
        return lifetime_s * u.s

    @lru_cache(maxsize=65536)
    def tau_eff(self, trap_detuning):
        """
        Effective lifetime of the Rydberg state at the trap center (X = 0).

        Combines two contributions as 1 / tau_eff = 1 / tau_6p + 1 / tau_rydberg_ns:
        - tau_6p, the lifetime contribution from the dressing to the 6P_3/2 state by the trapping light. It scales
          as trap_detuning^2 * tau_6p_0 / trap_depth.
        - tau_rydberg_ns, the room-temperature lifetime of the Rydberg state itself.

        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: effective lifetime (pint quantity)
        """
        X = 0
        D_n = self.reduced_dipole_matrix_element_n()
        power_ratio = self.trap_power_ratio()
        power_factor = (power_ratio ** (-1 / 4) + power_ratio ** (1 / 4)) ** 2

        dressing_lifetime = self.alpha_g() / self.trap_depth(X) * (
                16 * hbar ** 2 * trap_detuning ** 2 * self.tau_6p_0()) / (D_n ** 2 * power_factor)
        return 1 / (1 / dressing_lifetime + 1 / self.tau_rydberg_ns())

    def calc_decay(self, tlist, trap_detuning):
        """
        Radiative decay factor exp(-t / tau_eff) of the Rydberg excitation for each time in tlist.

        :param tlist: times (array-like)
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: surviving fraction for every entry of tlist
        """
        lifetime = self.tau_eff(trap_detuning)
        return np.exp(-tlist / lifetime)

    @lru_cache(maxsize=65536)
    def omega_d(self, X, rho, trap_detuning):
        """
        Difference between the Rydberg and ground state non-lattice potentials, expressed as an angular frequency.

        :param X: axial position in the trap
        :param rho: transversal distance rho relative to the beam axis of the trap
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: (V_rydberg - V_ground) / hbar
        """
        rydberg_potential = self.rydberg_state_non_lattice_potential(X=X, rho=rho, trap_detuning=trap_detuning)
        ground_potential = self.ground_state_non_lattice_potential(X=X, rho=rho)
        return (rydberg_potential - ground_potential) / hbar

    @lru_cache(maxsize=65536)
    def excitation_retrieval_overlap(self, X, rho):
        """
        Spatial overlap factor of the excitation and retrieval beams at (X, rho).

        Evaluates g_12 * g_23^2 with g_i = w_i(0) / w_i(X) * exp(-rho^2 / w_i(X)^2), using waist_12 / waist_23.
        :param X: axial position in the trap
        :param rho: transversal distance rho relative to the beam axis of the trap
        :return: overlap as a plain number
        """
        waist_12_x = self.waist_12(X)
        waist_23_x = self.waist_23(X)
        profile_12 = self.waist_12(0 * u.m) / waist_12_x * np.exp(-rho ** 2 / waist_12_x ** 2)
        profile_23 = self.waist_23(0 * u.m) / waist_23_x * np.exp(-rho ** 2 / waist_23_x ** 2)
        return (profile_12 * profile_23 ** 2).magnitude

#    @lru_cache(maxsize=65536)
    def atomic_density(self, X, rho):
        """
        Unnormalised atomic density of the ground state cloud.

        Product of two factors, returned in units of 1/um^3:
        - the Boltzmann factor of the ground state potential (non-lattice + lattice part) at temperature T_atoms;
        - the Gaussian axial envelope exp(-X^2 / cloud_length^2).
        :param X: axial position in the trap (scalar or ndarray)
        :param rho: transversal distance rho relative to the beam axis of the trap (scalar or ndarray)
        :return: atomic density
        """
        potential = (self.ground_state_non_lattice_potential(X, rho)
                     + self.ground_state_lattice_potential(X, rho))
        boltzmann_factor = np.exp(-potential / (u.k_B * self.T_atoms))
        axial_envelope = np.exp(-X ** 2 / self.cloud_length ** 2)
        return (boltzmann_factor * axial_envelope) * 1 / (u.um ** 3)

    @lru_cache(maxsize=65536)
    def total_motional_dephasing(self, X, rho, Ts, trap_detuning):
        """
        Motional dephasing factor at the position (X, rho) after the storage time Ts (eq. 27).

        Product of two terms:
        - the transversal phase exp(-i omega_d(X, rho) Ts);
        - the lattice term sum_{m_g, m_r} p(m_g) |M(m_g, m_r)|^2 exp(i (E_g(m_g) - E_r(m_r)) Ts / hbar).
        Here p are the thermal ground state populations and M the motional matrix elements from
        get_matrix_element_q.
        :param X: axial position in the trap
        :param rho: transversal distance rho relative to the beam axis of the trap
        :param Ts: storage time
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: complex dephasing factor
        """
        ground = self.ground_state_quantized_motion(X)
        rydberg = self.rydberg_state_quantized_motion(trap_detuning=trap_detuning, X=X)

        phase_g = (ground['eigenenergies'] * Ts / hbar).to('dimensionless').magnitude
        phase_r = (rydberg['eigenenergies'] * Ts / hbar).to('dimensionless').magnitude

        # rows: ground modes, columns: Rydberg modes
        populations = self.thermal_distribution_rho_ground(X)[:, np.newaxis]
        lattice_phase = np.exp(1j * (phase_g[:, np.newaxis] - phase_r[np.newaxis, :]))
        transversal_phase = np.exp(-1j * self.omega_d(X, rho, trap_detuning) * Ts).magnitude

        overlaps = np.array(
            [[self.get_matrix_element_q(trap_detuning=trap_detuning, m_ground=m_g, m_rydberg=m_r, X=X)
              for m_r in np.arange(0, rydberg['num_of_solutions'], 1)]
             for m_g in np.arange(0, ground['num_of_solutions'], 1)],
            dtype=np.complex128)

        return transversal_phase * np.sum(populations * np.abs(overlaps) ** 2 * lattice_phase)

    @lru_cache(maxsize=65536)
    def _efficiency_element(self, X, rho, Ts, trap_detuning):
        """
        Unnormalised integrand of efficiency_integration in cylindrical coordinates.

        :param X: axial position in um (plain float, as passed by scipy's quad)
        :param rho: transversal distance from the beam axis in um (plain float)
        :param Ts: storage time
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: complex value rho * n(X, rho) * D(X, rho, Ts) * O(X, rho), where
            - n is the atomic density (magnitude in 1/um^3);
            - D is total_motional_dephasing;
            - O is excitation_retrieval_overlap.
            The density normalisation is applied by efficiency_integration.
        """
        X_um = X * u.um
        rho_um = rho * u.um

        # quad evaluates this integrand at single points. A pointwise normalisation n / np.sum(n) would be identically
        # 1 (or nan where n underflows to 0) and cancel the density weighting. The normalisation is done over the whole
        # domain instead.
        return (rho * self.total_motional_dephasing(X_um, rho_um, Ts, trap_detuning)
                * self.atomic_density(X_um, rho_um).magnitude
                * self.excitation_retrieval_overlap(X_um, rho_um))

    @lru_cache(maxsize=65536)
    def _atomic_density_norm(self, x_min_um, x_max_um, rho_min_um, rho_max_um):
        """
        Integral of rho * n(X, rho) over x_min_um <= X <= x_max_um and rho_min_um <= rho <= rho_max_um.

        Uses the same integration routine and units (um, 1/um^3 magnitudes) as the numerator in
        efficiency_integration.
        :return: normalisation integral as a float
        :raises ValueError: if the density does not integrate to a positive value over the domain
        """
        def weighted_density(x, rho):
            return rho * self.atomic_density(x * u.um, rho * u.um).magnitude

        norm = double_complex_integral(weighted_density, x_min_um, x_max_um, rho_min_um, rho_max_um).real
        if not norm > 0:
            raise ValueError('atomic density vanishes over the integration domain')
        return norm

    @lru_cache(maxsize=65536)
    def efficiency_integration(self, Ts, trap_detuning):
        """
        Calculate the efficiency of the excitation of the Rydberg state at time Ts using the full integration (eq. 26).

        The integral runs over -cloud_length <= X <= cloud_length and 0 <= rho <= 15 um. The integrand
        _efficiency_element is integrated and divided by the integral of rho * n(X, rho) over the same domain, so
        the atomic density n acts as a normalised weight. The squared modulus is multiplied by the radiative decay
        exp(-Ts / tau_eff).
        :param Ts: measurement time
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: efficiency of the excitation of the Rydberg state at time Ts
        """
        x_min_um = (-self.cloud_length).to('um').magnitude
        x_max_um = self.cloud_length.to('um').magnitude
        rho_min_um, rho_max_um = 0.0, 15.0

        numerator = double_complex_integral(self._efficiency_element, x_min_um, x_max_um, rho_min_um, rho_max_um,
                                            Ts=Ts, trap_detuning=trap_detuning)
        efficiency = numerator / self._atomic_density_norm(x_min_um, x_max_um, rho_min_um, rho_max_um)

        return np.abs(efficiency) ** 2 * np.exp(-Ts / self.tau_eff(trap_detuning)).magnitude

    def efficiency_integration_t(self, tlist, trap_detuning):
        """
        Wrapper for efficiency_integration to calculate the efficiency for an array of times.

        :param tlist: sequence of storage times
        :param trap_detuning: detuning of the trapping laser w.r. to the state 6P_3/2
        :return: ndarray with one efficiency per entry of tlist (same type for both code paths)
        """
        if self.__multiprocessing:
            # process_map manages its own process pool. The former extra multiprocessing.Pool was never used and
            # only spawned idle workers.
            func = partial(self.efficiency_integration, trap_detuning=trap_detuning)
            return np.array(process_map(func, tlist))

        effs = np.empty(shape=len(tlist))
        for i in tqdm(range(len(tlist))):
            effs[i] = self.efficiency_integration(tlist[i], trap_detuning=trap_detuning)
        return effs

    def efficiency_integration_delta(self, Ts, trap_detuning):
        """
        Wrapper for efficiency_integration to calculate the efficiency for an array of detunings.

        :param Ts: measurement time
        :param trap_detuning: sequence of detunings of the trapping laser w.r. to the state 6P_3/2
        :return: ndarray with one efficiency per entry of trap_detuning (same type for both code paths)
        """
        if self.__multiprocessing:
            # process_map manages its own process pool. The former extra multiprocessing.Pool was never used and
            # only spawned idle workers.
            func = partial(self.efficiency_integration, Ts)
            return np.array(process_map(func, trap_detuning))

        effs = np.empty(shape=len(trap_detuning))
        for i in tqdm(range(len(trap_detuning))):
            effs[i] = self.efficiency_integration(Ts, trap_detuning=trap_detuning[i])
        return effs
