"""
This script computes the mean and standard deviation of the differential light shift as a function of the trap detuning
and plots the standard deviation for Figure 3. The figure is saved as a PDF in the current working directory.
"""

import numpy as np
import matplotlib.pyplot as plt
from time import strftime
from scipy.io import loadmat

import simulation_class as rsd

# Unit registry provided by the simulation module; all quantities below are built from it.
u = rsd.u

plt.style.use("paperstyle.mplstyle")

pt = 1 / 72  # one typographic point (1/72 inch); used to size the figure in points
datestr = strftime('%Y-%m-%d')  # today's date, reused as the prefix of the output filename

# %% simulate detuning sample

# Configuration of the simulated cold gas, collected once and handed to rsd.ColdGas unchanged.
_cold_gas_kwargs = dict(
    rydberg_n=76,
    waist_12=5 * u.micrometer,
    waist_23=40 * u.micrometer,
    waist_trap_p=24 * u.micrometer,
    waist_trap_m=24 * u.micrometer,
    X_0_p=600 * u.micrometer,
    X_0_m=0 * u.micrometer,
    P_p=17 * u.mW,
    P_m=0.75 * 17 * u.mW,
    cloud_length=20 * u.micrometer,
    T_atoms=2e-6 * u.kelvin,
    room_temp=293 * u.K,
    n_of_x=1000,
    truncation=9,
    multiprocessing=False,
    rydberg_only_free_electron=True,
)
cg_free = rsd.ColdGas(**_cold_gas_kwargs)

# NOTE(review): not referenced anywhere else in the visible script.
setup_specified_text = 'RQO'

# %% reading in

mat_contents = loadmat('data_FIG3_mean_std_vs_detuning.mat')

# Differential light-shift maps; each is sliced as [detuning_index, :, :] further down.
_diff_map_keys = (
    'potential_2D_diff_lattice',
    'potential_2D_diff_non_lattice',
    'potential_2D_diff_all',
    'potential_2D_diff_running_wave',
)
(potential_2D_diff_lattice_loaded,
 potential_2D_diff_non_lattice_loaded,
 potential_2D_diff_all_loaded,
 potential_2D_diff_running_wave_loaded) = (mat_contents[key] for key in _diff_map_keys)

# loadmat returns vectors as 2D (1, N) arrays, so row 0 holds the actual 1D data.
detunings_2D_maps_loaded = mat_contents['factors_detuning'][0]
X_potential_plot = mat_contents['X_potential_plot'][0] * u.m
rho_potential_plot = mat_contents['rho_potential_plot'][0] * u.m


def _ground_state_map(potential_method):
    """Evaluate ``potential_method(X, rho)`` on the (rho, X) grid and return it in units of k_B*uK.

    Rows follow ``rho_potential_plot`` and columns follow ``X_potential_plot``.
    """
    grid = np.zeros(shape=(len(rho_potential_plot), len(X_potential_plot)))
    for row, rho_value in enumerate(rho_potential_plot):
        # rho is bound as a default so each row's lambda uses that row's radius.
        sample_row = np.vectorize(
            lambda X_, rho_=rho_value: potential_method(X_ * u.m, rho_).to('k_B*uK').magnitude,
            otypes=[np.ndarray])
        grid[row] = sample_row(X_potential_plot)
    return grid


potential_2D_gs_lattice = _ground_state_map(cg_free.ground_state_lattice_potential)
potential_2D_gs_running_wave = _ground_state_map(cg_free.ground_state_running_wave_potential)

# %% Extracting mean or standard deviation of light shift within spatial region

factors = detunings_2D_maps_loaded

# Converts an energy given in k_B*uK into E/hbar expressed in kHz.
# NOTE(review): dividing by hbar (not h) yields an angular frequency; confirm this matches
# the "kHz" axis label used in the plot.
uKtokHz = (1 * u.boltzmann_constant / u.hbar * u.uK).to('Hz').magnitude * 1e-3

temperature_fraction = [2]  # [0.5, 1, 2, 4]

n_temperatures = len(temperature_fraction)
n_factors = len(factors)
mean_pot_diff_lattice = np.zeros(shape=(n_temperatures, n_factors))
mean_pot_diff_both = np.zeros(shape=(n_temperatures, n_factors))
mean_pot_diff_running_wave = np.zeros(shape=(n_temperatures, n_factors))
std_pot_diff_lattice = np.zeros(shape=(n_temperatures, n_factors))
std_pot_diff_both = np.zeros(shape=(n_temperatures, n_factors))
std_pot_diff_running_wave = np.zeros(shape=(n_temperatures, n_factors))

# Per-row weight rho / max(rho), repeated along the X axis -> shape (len(rho), len(X)).
rho_scaling_factors = np.repeat([rho_potential_plot.magnitude / max(rho_potential_plot.magnitude)],
                                len(X_potential_plot), axis=0).transpose()

# Loop-invariant quantities, hoisted out of the temperature/detuning loops.
# rho is stored in metres and converted to um here, then compared with twice waist_12(0).
# NOTE(review): assumes waist_12(0).magnitude is expressed in micrometres — confirm.
rows_outside_waist = np.abs(rho_potential_plot.magnitude * 1e6) > cg_free.waist_12(0).magnitude * 2
# NOTE(review): assumes T_atoms is held in kelvin (as passed to ColdGas); *1e6 gives uK,
# the same unit as the k_B*uK ground-state maps it is compared against.
T_atoms_uK = cg_free.T_atoms.magnitude * 1e6


def _light_shift_stats(diff_map, gs_potential, temperature_factor):
    """Return (mean, std), in kHz, of the rho-weighted light shift over contributing grid points.

    A grid point contributes when its |rho| is at most twice waist_12(0), its ground-state
    potential lies within ``T_atoms_uK * temperature_factor`` of the minimum of ``gs_potential``,
    and its weighted value is non-zero. Returns (0.0, 0.0) if no grid point contributes.
    """
    weighted = np.array(diff_map)  # copy so the masking below never touches the loaded data
    weighted[rows_outside_waist, :] = 0
    weighted = weighted * np.abs(rho_scaling_factors)
    weighted[gs_potential > gs_potential.min() + T_atoms_uK * temperature_factor] = 0
    # Zero doubles as the "excluded" marker, so genuine zero values are dropped as well.
    contributing = weighted[weighted != 0]
    if contributing.size == 0:
        # Same value the nan_to_num clean-up below produces, without the empty-slice warning.
        return 0.0, 0.0
    return np.mean(contributing) * uKtokHz, np.std(contributing) * uKtokHz


for index_temp, temperature_factor in enumerate(temperature_fraction):
    for index in range(n_factors):
        (mean_pot_diff_lattice[index_temp, index],
         std_pot_diff_lattice[index_temp, index]) = _light_shift_stats(
            potential_2D_diff_lattice_loaded[index, :, :], potential_2D_gs_lattice, temperature_factor)

        # The combined map is masked with the lattice ground-state potential.
        (mean_pot_diff_both[index_temp, index],
         std_pot_diff_both[index_temp, index]) = _light_shift_stats(
            potential_2D_diff_all_loaded[index, :, :], potential_2D_gs_lattice, temperature_factor)

        (mean_pot_diff_running_wave[index_temp, index],
         std_pot_diff_running_wave[index_temp, index]) = _light_shift_stats(
            potential_2D_diff_running_wave_loaded[index, :, :], potential_2D_gs_running_wave,
            temperature_factor)

# Any NaN left (e.g. from NaN entries in the input maps) is reported as 0.
mean_pot_diff_lattice = np.nan_to_num(mean_pot_diff_lattice, nan=0)
std_pot_diff_lattice = np.nan_to_num(std_pot_diff_lattice, nan=0)
mean_pot_diff_both = np.nan_to_num(mean_pot_diff_both, nan=0)
std_pot_diff_both = np.nan_to_num(std_pot_diff_both, nan=0)
mean_pot_diff_running_wave = np.nan_to_num(mean_pot_diff_running_wave, nan=0)
std_pot_diff_running_wave = np.nan_to_num(std_pot_diff_running_wave, nan=0)

# argmin on these 2D (temperature, detuning) arrays returns a flattened index; unravel it so the
# column indexes `factors` correctly even when several temperatures are evaluated.
magic_detuning_half = cg_free.calc_magic_detuning().magnitude / 2
best_mean_column = np.unravel_index(np.abs(mean_pot_diff_both).argmin(), mean_pot_diff_both.shape)[1]
best_std_column = np.unravel_index(std_pot_diff_both.argmin(), std_pot_diff_both.shape)[1]
print(factors[best_mean_column] * magic_detuning_half)
print(factors[best_std_column] * magic_detuning_half)

# %% Plot the standard deviation of the extracted light shift vs trap detuning,
# taken over the contributing atoms in the potential area (one curve set per temperature)

color_rw = '#004d9f'  # '#004d9f' (uni bonn blue),    '#113060'
color_sw = '#c43f53'  # '#a11e3b' (uni tübingen red), '#c43f53'
color_periodic = '#fcba00'  # '#fcba00' (uni bonn yellow),  '#e3a239'

# Half of each magic detuning returned by the simulation, as used throughout this script.
# NOTE(review): the *1e3 factors below presumably convert GHz to the MHz axis — confirm units.
magic_half_lattice = cg_free.calc_magic_detuning().magnitude / 2
magic_half_running_wave = cg_free.calc_magic_detuning_running_wave().magnitude / 2
detuning_axis = factors * magic_half_lattice * 1e3

# Detuning with the smallest std for the combined (standing-wave) map. np.argmin on the 2D
# (temperature, detuning) array is flattened, so unravel it to a detuning column index.
sw_column = np.unravel_index(np.argmin(std_pot_diff_both), std_pot_diff_both.shape)[1]
magic_sw = factors[sw_column] * magic_half_lattice * 1e3

fig = plt.figure(figsize=(246 * pt, 135 * pt), dpi=800, constrained_layout=True)
plt.axhline(0, color='grey')
# plot magic detunings
plt.axvline(magic_half_lattice * 1e3, linestyle='dashed', color=color_periodic)
plt.text(magic_half_lattice * (1 + 1 / 300) * 1e3, -11.5, r'$\Delta^\sim_{76}$',
         verticalalignment='bottom', horizontalalignment='left', color=color_periodic)

plt.axvline(magic_half_running_wave * 1e3, linestyle='dashed', color=color_rw)
plt.text(magic_half_running_wave * (1 + 1 / 300) * 1e3, -11.5,
         r'$\Delta^\mathrm{rw}_{76}$',
         verticalalignment='bottom', horizontalalignment='left', color=color_rw)

# Marker and label both sit at the computed standing-wave std minimum.
plt.axvline(magic_sw, ymin=0.15, ymax=0.27, linestyle='dotted', color=color_sw)
plt.text(magic_sw * (1 + 1 / 600), -11.5, r'$\Delta^\mathrm{sw}_{76}$',
         verticalalignment='bottom', horizontalalignment='left', color=color_sw)

single_temperature = len(temperature_fraction) == 1
for i, temperature_factor in enumerate(temperature_fraction):
    # Opacity maps temperature factors 0.5 ... 4 linearly onto 0.3 ... 1.0.
    alpha_temp = 1 if single_temperature else (temperature_factor - 0.5) * 0.7 / 3.5 + 0.3
    plt.plot(detuning_axis, std_pot_diff_running_wave[i, :], '-',
             label='Running wave', color=color_rw, alpha=alpha_temp)
    plt.plot(detuning_axis, std_pot_diff_both[i, :], '-',
             label='Standing wave', color=color_sw, alpha=alpha_temp)
    plt.plot(detuning_axis, std_pot_diff_lattice[i, :], '-',
             label='Periodic', color=color_periodic, alpha=alpha_temp)

# Each label repeats once per temperature; keep a single legend entry per label.
handles, labels = plt.gca().get_legend_handles_labels()
unique_handles = list(dict(zip(labels, handles)).values())

plt.legend(handles=unique_handles, loc='upper center', bbox_to_anchor=(0.5, 1.14), ncol=3,
           handletextpad=0.4, borderpad=0.3, borderaxespad=0.2)

plt.ylim([-15, 100])
plt.xlim([280, 580])
plt.xlabel(r'Trap detuning $\Delta/(2\pi\cross\mathrm{MHz})$', fontsize=7.5, labelpad=3)
plt.ylabel(r'Differential light shift (kHz)', fontsize=7.5, labelpad=3)
plt.minorticks_on()
# NOTE(review): the figure was created with constrained_layout=True; tight_layout() replaces that
# layout engine (recent Matplotlib versions may warn). Consider keeping only one of the two.
plt.tight_layout(pad=0.05)
plt.savefig(datestr.replace('-', '') + '-FIG3-mean-st-light-shift-vs-detuning.pdf', dpi=800)
plt.show()
