# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:light
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.16.1
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---

# +
import os
import sys
os.environ['QT_QPA_PLATFORM']='offscreen'

# Scientific
import numpy  as np
from scipy.optimize import curve_fit

# Graphics
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import rcParams
from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Global matplotlib style applied to every figure produced below:
# 300 dpi, text rendered through LaTeX (with the amsfonts and nicefrac packages)
# and STIX fonts.
rcParams.update({
  "figure.dpi": 300,
  'text.usetex': True,
  'text.latex.preamble': r'\usepackage{amsfonts,nicefrac}',
  'mathtext.fontset': "stix",
  'font.family': 'STIXGeneral'
})


# NOTE: this rebinds the name `cm`, shadowing the `matplotlib.cm` module imported above;
# from here on `cm` is a plain float used to convert figure sizes.
cm = 1.0/2.54  # inches per centimeter: multiply a length in cm by this to get inches

# Width/height ratio used for the figure sizes below
goldenratio=1.618
# -

# # Definitions

# +
# We use the detuning as the frequency unit, while all data files use omega0 as the frequency unit.
# DeltaOverW0 is the factor applied to times (and passed as Delta) to convert between the two.
DeltaOverW0 = 1.0/2.0

# Locations of the input data and of the output figures (relative paths, resolved
# against the current working directory).
basedir = '../'
data_dir = basedir+'data/'
fig_dir = basedir+'figures/'


def k_eff_nph(g,k,nph,Delta):
    ''' Effective rate 16 (g^2/Delta)^2 nph / k for a given photon number nph '''
    g2_over_delta = g**2/Delta
    return 16*g2_over_delta**2 *nph/k

def nph_eps(k,eps):
    ''' Photon number (2 eps / k)^2 associated with drive eps and rate k '''
    amplitude = 2*eps/k
    return amplitude**2

def k_eff_eps(g,k,eps,Delta):
    ''' Effective rate as in k_eff_nph, with the photon number obtained from the drive via nph_eps '''
    nph = nph_eps(k,eps)
    return k_eff_nph(g,k,nph,Delta)

def Mancini_squeezing(t, N, k_eff):
    ''' Return exp(k_eff t) / (1 + N k_eff t) '''
    numerator = np.exp(k_eff*t)
    denominator = 1 + N*k_eff*t
    return numerator /denominator

def contrastAnalytic(t):
    ''' Analytic contrast: exponential decay exp(-t) '''
    decay = np.exp(-t)
    return decay

def varXAnalytic(tau,J):
    ''' Analytic X variance: J (1 - exp(-tau))^2 + (1 - exp(-2 tau))/2 '''
    first_order = 1-np.exp(-tau)
    second_order = 1-np.exp(-2*tau)
    return J*first_order**2+second_order/2

def varYAnalytic(tau,J):
    ''' Analytic Y variance: J (1 - exp(-2 tau)) + (1 + exp(-2 tau))/2 '''
    decay = np.exp(-2*tau)
    return J*(1-decay)+(1+decay)/2

def varZAnalytic(tau,J):
    ''' Analytic Z variance: 1 / (1 + 2 J tau) '''
    denominator = 1+2*J*tau
    return 1/denominator

def xiAnalyticC(tau,cos2th,J):
    '''
    Conditional spin squeezing

    Mixes the analytic Z and X variances with weights sin^2 = 1 - cos2th and
    cos^2 = cos2th, and divides the result by the analytic contrast.
    '''
    sin2th = 1 - cos2th
    mixed_variance = varZAnalytic(tau,J)*sin2th+varXAnalytic(tau,J)*cos2th
    return mixed_variance/contrastAnalytic(tau)

def xiAnalytic(tau,J):
    '''
    Average spin squeezing in large J limit

    Evaluates xiAnalyticC with cos2th = 1 / (2 J contrast(tau)).
    '''
    contrast = contrastAnalytic(tau)
    return xiAnalyticC(tau,1/(2.0*J*contrast),J)

def check_outbadcavity(kappa,Na,g,Delta):
    ''' True when kappa < Na g^2 / Delta, i.e. the point is NOT in the bad-cavity regime '''
    threshold = Na*g**2/Delta
    return kappa < threshold

def choose_marker_type(k, N, g, Delta, color):
    ''' Marker face colour: hollow ('none') outside the bad-cavity regime, filled with color otherwise '''
    return 'none' if check_outbadcavity(k,N,g,Delta) else color

def find_nearest(array, value):
    ''' Index of the element of array closest to value (first one in case of ties) '''
    distances = np.abs(np.asarray(array) - value)
    return distances.argmin()
  

def save_figure(filename):
    ''' Save the current pyplot figure as <filename>.pdf and <filename>.png (tight bounding box, 300 dpi) '''
    # An 'eps' export used to be listed here as well; it is currently disabled.
    for extension in ('pdf', 'png'):
        plt.savefig(f'{filename}.{extension}',format = extension,bbox_inches = 'tight', dpi=300)

## INDEXES

# Column indices of the summary ("final") data files: one row per simulated parameter set.
# NOTE(review): the meanings below are inferred from how the plotting routines
# use each column — confirm against the code that writes the data files.
iE      = 0   # drive amplitude (passed as eps to k_eff_eps)
iK      = 1   # cavity rate kappa
iG      = 2   # coupling g
iN      = 3   # number of atoms N
iNph    = 4   # photon number
iTo     = 5   # optimal time t_m
iTo_err = 6   # uncertainty on the optimal time
iKU     = 7   # squeezing parameter at the optimum
iKU_err = 8   # uncertainty on the squeezing parameter

# Column indices in the time-resolved 'ad_sq.txt' data files
i_t = 0             # time

i_ad_KU     = 1     # squeezing parameter
i_ad_KU_err = 2     # ... and its error
i_ad_Photon = 7     # photon number
i_ad_Ph_err = 8     # ... and its error
i_ad_Jz     = 19    # Jz spin component
i_ad_Jz_err = 20    # ... and its error

# For trajectories: row indices of the array obtained by loading the seven
# columns above (in that order) and transposing, as done in plot_xi_trajectories_generic
it = 0
iXi = 1
iXierr = 2
iPh = 3
iPherr = 4
iJzi = 5
iJzierr = 6

# Legends 
# Colour/marker per distinct (kappa, drive) pair in the N-dependence plots
color_array_n = ['xkcd:golden yellow','xkcd:bright orange','xkcd:kelly green', 'xkcd:violet']
marker_array_n = ['o','v','s','^']

# Colour per distinct (kappa, drive) pair and marker per atom number in the g-dependence plots
color_array_k = ['xkcd:pink','xkcd:cerulean','xkcd:navy blue', 'xkcd:red','xkcd:gray']
marker_array_k = ['D','P']


# -

# # Dependence on $N$

# ## $\xi^2$ scaling

# +
def plot_final_comparison():
    '''
    Plot the minimum squeezing parameter against the atom number N and save the figure.

    Reads 'minimum/opt_xi_cavity_removal.txt' and 'minimum/final_squeezing_N.txt'
    from data_dir, draws both data sets with error bars on log-log axes together
    with the reference lines 1 (SQL), e/N and (3/2)/N^(2/3), and writes
    '<fig_dir>final_comparison' through save_figure.
    '''
    Delta = 1

    # Proxy artists: the legend only lists the cavity-removal markers and the two scaling curves
    legend_elem = [ Line2D([0], [0], marker='D', color='w', mec='red', mfc='red',label = 'cavity removal',markersize=3),
                    Line2D([0], [0], color='purple', linestyle = '-.', linewidth=1, label = "$(3/2)/N^{2/3}$"),
                    Line2D([0], [0], color='b', linestyle = ':',linewidth=1, label = r'$e/N$')]

    def plot_xi(ax, data_array):
        ''' Draw one error-bar point per row of data_array (2-D, columns indexed by iE, iK, ...) '''
        params = []  # distinct [k, e] pairs, in order of first appearance

        for row in data_array:
            e = row[iE]
            k = row[iK]
            g = row[iG]
            n = row[iN]

            # Each distinct (k, e) pair gets its own colour and marker.
            # NOTE: only len(color_array_n) distinct pairs are supported.
            if [k,e] not in params:
                params.append([k,e])
            i_color = params.index([k,e])

            color_line = color_array_n[i_color]
            mface = choose_marker_type(k, n, g, Delta, color_line)
            mark = marker_array_n[i_color]

            ax.errorbar(n,row[iKU],yerr=row[iKU_err], c=color_line, mfc = mface, linestyle='none', marker=mark,markersize=5,zorder=4)

    # plot specifics
    imagewidth = 10 # cm
    _, axs = plt.subplots(1,1,figsize=[imagewidth*cm,imagewidth*cm/goldenratio], dpi=300)

    Nmin=30
    Nmax=22000

    X_vec = np.linspace(Nmin, Nmax,100)

    # Fixed scaling lines
    # SQL
    axs.hlines([1.0],0,Nmax,color='k',linewidth=1,linestyle="--")
    # Thomsen,Mancini,Wiseman Heisenberg scaling
    axs.plot(X_vec, np.e/(X_vec),c='b',linestyle=':',linewidth=1,zorder=6)
    # Analytical scaling
    axs.plot(X_vec, 3/(2*X_vec**(2/3.0)),c='purple', linestyle = '-.',linewidth=1,zorder=6)

    # load data
    # Cavity removal: unpack=True gives one array per file column
    CavityRemovalSqueezing = np.loadtxt(data_dir+'minimum/opt_xi_cavity_removal.txt', unpack=True)
    axs.errorbar(CavityRemovalSqueezing[0],CavityRemovalSqueezing[2], CavityRemovalSqueezing[3], linestyle='none', marker = 'D',c='red',zorder=2,markersize=3)

    # full simulations: one row per parameter set; ndmin=2 keeps a single-row file two-dimensional
    data_squeezing = np.loadtxt(data_dir+'minimum/final_squeezing_N.txt', ndmin=2)
    plot_xi(axs, data_squeezing)

    axs.set_xlabel(r'$N$')
    axs.set_ylabel(r'$\xi_m^2$')
    axs.set_ylim(1e-3,1.5)
    axs.set_xlim(1,Nmax)
    axs.set_xscale('log')
    axs.set_yscale('log')

    axs.legend(handles = legend_elem, loc=(0.02,0.02),frameon=True, framealpha=1,edgecolor="w", handlelength=1)

    axs.annotate(r'SQL',(4000,0.55))
    axs.tick_params(axis='both',which='both',direction="in")
    save_figure(f'{fig_dir}final_comparison')

plot_final_comparison()
# -
# ## Time scaling


# +
def plot_final_comparison_timescale():
    '''
    Plot the rescaled optimal time (k_eff * t_m) against the atom number N and save the figure.

    The main axes show the cavity-removal data and the full-simulation data together
    with the reference curve 1/N^(1/3); an inset repeats the plot with the constant
    6*k_eff/k subtracted from the full-simulation points. The result is written to
    '<fig_dir>final_comparison_timescale' through save_figure.
    '''
    Delta = 1

    def plot_t_reduced(ax, data_array, alpha, offset):
        '''
        Draw k_eff * t_m * N^alpha for each row of data_array (2-D, columns indexed by iE, iK, ...).
        When offset is true, 6*k_eff/k is subtracted from the plotted value.
        '''
        params = []  # distinct [k, e] pairs, in order of first appearance

        for row in data_array:
            e = row[iE]
            k = row[iK]
            n = int(row[iN])
            g = row[iG]

            k_eff = k_eff_eps(g,k,e,Delta)

            # Each distinct (k, e) pair gets its own colour and marker
            if [k,e] not in params:
                params.append([k,e])
            i_color = params.index([k,e])

            color_line = color_array_n[i_color]
            mface = choose_marker_type(k, n, g, Delta, color_line)
            mark = marker_array_n[i_color]

            off = 6*k_eff/k if offset else 0

            ax.errorbar(n, row[iTo]*k_eff*(n**alpha)-off, row[iTo_err]*k_eff*(n**alpha), c=color_line, mfc = mface, linestyle='none', marker=mark,markersize=5,zorder=4)

        # The scale does not depend on the data: set it once instead of once per point
        if len(data_array) > 0:
            ax.set_xscale('log')

    time_legend_elem = [ Line2D([0], [0], marker='D', color='w', mec='red', mfc='red',label = 'cavity removal',markersize=3),
                         Line2D([0], [0], color='purple', linestyle = '-.', linewidth=1, label = r'$1/N^{1/3}$')]

    imagewidth = 10 # cm
    _, axs = plt.subplots(1,1,figsize=[imagewidth*cm,imagewidth*cm/goldenratio], dpi=300, constrained_layout=True)

    X_vec = np.linspace(5,20000,100)

    # load data
    # Cavity removal: unpack=True gives one array per file column
    CavityRemovalSqueezing = np.loadtxt(data_dir+'minimum/opt_xi_cavity_removal.txt', unpack=True)
    # The rate used to rescale the cavity-removal times is read from the first row of column 1
    cav_keff = CavityRemovalSqueezing[1,0]
    axs.errorbar(CavityRemovalSqueezing[0],CavityRemovalSqueezing[4]*cav_keff, CavityRemovalSqueezing[5]*cav_keff, linestyle='none',  marker = 'D',c='red',markersize=3,zorder=2)

    # full simulations: one row per parameter set; ndmin=2 keeps a single-row file two-dimensional
    data_squeezing = np.loadtxt(data_dir+'minimum/final_squeezing_N.txt', ndmin=2)
    plot_t_reduced(axs, data_squeezing, 0,False)

    axs.plot(X_vec, 1/(X_vec**(1/3)),c='purple', linestyle = '-.',linewidth=1,zorder=6)

    axs.set_xlabel(r'$N$')
    axs.set_ylabel(r'$\tilde{\kappa}\; t_m$')
    axs.set_xscale('log')
    axs.set_yscale('log')
    axs.set_xlim(3,25000)
    axs.tick_params(axis='both',which='both',direction="in")

    axs.legend(handles = time_legend_elem, loc='lower left',frameon=True, edgecolor="white",framealpha=1, handlelength=1)

    # Inset: same data with the offset subtracted from the full-simulation points
    axins = inset_axes(axs, width="33%", height="33%")
    axins.plot(X_vec, 1/(X_vec**(1/3)),c='purple', linestyle = '-.',linewidth=1,zorder=6)

    plot_t_reduced(axins,data_squeezing, 0,True)
    axins.errorbar(CavityRemovalSqueezing[0],CavityRemovalSqueezing[4]*cav_keff, CavityRemovalSqueezing[5]*cav_keff, linestyle='none',  marker = 'D',c='red',markersize=3,zorder=2)
    axins.set_xscale('log')
    axins.set_yscale('log')
    axins.set_xlim(3,300)
    axins.set_ylim(0.06,1.)
    axins.set_ylabel(r'$\tilde{\kappa} \left(t_m-2c/\kappa\right)$',fontsize=8,y=0.5,labelpad=1.5)
    axins.tick_params(labelsize=8)
    axins.tick_params(axis='both',which='both',direction="in")

    save_figure(f'{fig_dir}final_comparison_timescale')

plot_final_comparison_timescale()


# -
# # Dependence on $g$

# ## Dependence of optimal time on $\tilde{\kappa}/\kappa$

# +
def plot_t_opt_keff():
    '''
    Plot the rescaled optimal time k_eff * t_m * N^beta against k_eff/k and save the figure.

    Reads 'minimum/final_xi_t_g.txt' from data_dir, draws one error-bar point per row
    (colour per (k, e) pair, marker per atom number) together with the curve
    b + 2 (k_eff/k) c N^beta, and writes '<fig_dir>t_opt_keff' through save_figure.
    '''
    Delta = 1
    params = []   # distinct [k, e] pairs -> colour index
    natoms = []   # distinct atom numbers -> marker index; also used to build the legend

    def plot_t_g_nfixed(results, ax, beta_t):
        ''' Draw the data points on ax and set its log scales and axis labels '''
        for row in results:
            k = row[iK]
            e = row[iE]
            t = row[iTo]
            terr = row[iTo_err]
            n = row[iN]
            g = row[iG]

            k_eff = k_eff_eps(g, k, e, Delta)

            if [k,e] not in params:
                params.append([k,e])
            i_c = params.index([k,e])

            # NOTE: only len(marker_array_k) distinct atom numbers are supported
            if n not in natoms:
                natoms.append(n)
            i_m = natoms.index(n)

            color_line = color_array_k[i_c]
            mface = choose_marker_type(k, n, g, Delta, color_line)
            mark = marker_array_k[i_m]

            tscale = k_eff*(n**(beta_t))

            ax.errorbar(k_eff/k, t*tscale, terr*tscale ,c=color_line, mec = color_line, mfc = mface, marker = mark, markersize=5)

        ax.set_yscale('log')
        ax.set_xscale('log')
        ax.set_ylabel(r'$\tilde{\kappa}\; t_m\;  N^{\beta}$')
        ax.set_xlabel(r'$\tilde{\kappa}/\kappa $ ')

    def f_vec(ktildeok,beta,N,b,c):
        ''' Reference curve b + 2 (k_eff/k) c N^beta '''
        return b + 2*ktildeok * c*N**beta

    imagewidth=10
    _, axs = plt.subplots(1,1,figsize=[imagewidth*cm,imagewidth*cm/goldenratio], dpi=300)

    # Parameters of the reference curve
    Na = 20
    beta_t = 0.32
    b = 0.86
    c = 3.0
    t_vec = np.logspace(-4,np.log10(0.7),40)
    axs.plot(t_vec, f_vec(t_vec,beta_t,Na,b,c),linestyle =':',color='purple',linewidth=1.5)

    # One row per parameter set; ndmin=2 keeps a single-row file two-dimensional
    data_G = np.loadtxt(data_dir+'minimum/final_xi_t_g.txt', ndmin=2)
    plot_t_g_nfixed(data_G, axs, beta_t)

    # One legend entry per atom number found in the data, plus the reference curve
    leg_handles_t = [
        Line2D([0], [0], marker=marker_array_k[i], color='w', markerfacecolor='k', label=f'$N = {int(n)}$',markersize=5)
        for i, n in enumerate(natoms)
    ]
    leg_handles_t.append( Line2D([0], [0], color='purple', linewidth=1.5, linestyle =':', label = r'$t_m = b/(\tilde{\kappa}N^\beta)+2 c/\kappa$'))

    axs.legend(handles = leg_handles_t, loc='upper left',frameon=False, edgecolor="white",framealpha=1, handlelength=1)

    axs.set_xlim(1e-4,50)
    axs.set_ylim(0.5,100)
    axs.tick_params(axis='both',which='both',direction="in")
    save_figure(f'{fig_dir}t_opt_keff')

plot_t_opt_keff()


# -
# ## Dependence of optimal squeezing on $g^2/\kappa\Delta$

# +
def plot_xi_g_nfixed(results, ax, Na, bc_cond):
    '''
    Draw the optimal squeezing for the rows of results whose atom number equals Na.

    results : 2-D array, one row per parameter set (columns indexed by iE, iK, ...)
    ax      : matplotlib axes to draw on; both of its scales are set to log
    Na      : atom number selecting the rows to plot
    bc_cond : if true the x value is g^2/(k Delta) and both axis labels are set;
              otherwise the x value is k_eff/k and only the x label is set

    A constant finitetimestepbias is added to the error bar of every point.
    '''
    finitetimestepbias=1.5e-2
    Delta = 1
    params = []  # distinct [k, e] pairs, in order of first appearance
    mark = marker_array_k[0]  # a single atom number is plotted, hence a single marker

    # The labels depend only on bc_cond: set them once, whether or not any row matches
    if bc_cond:
        ax.set_ylabel(r'$\xi^2_m$')
        ax.set_xlabel(r'$g^2/\kappa\Delta $ ')
    else:
        ax.set_xlabel(r'$\tilde{\kappa}/\kappa $ ')

    for row in results:
        n = int(row[iN])
        if n != Na:
            continue

        k = row[iK]
        e = row[iE]
        g = row[iG]
        xi = row[iKU]
        xerr = row[iKU_err] + finitetimestepbias

        # Each distinct (k, e) pair gets its own colour
        if [k,e] not in params:
            params.append([k,e])
        i_c = params.index([k,e])

        color_line = color_array_k[i_c]
        mface = choose_marker_type(k, n, g, Delta, color_line)

        if bc_cond:
            x = g**2/(k*Delta)
        else:
            x = k_eff_eps(g, k, e, Delta)/k

        ax.errorbar(x, xi, xerr, c=color_line, mec = color_line, mfc = mface, marker = mark, markersize=5,markeredgewidth=0.8,elinewidth=0.8)

    ax.set_xscale('log')
    ax.set_yscale('log')
    

    
def plot_xi_dep_g2k():
    '''
    Plot the optimal squeezing against g^2/(kappa Delta) for a single atom number and save the figure.

    Reads 'minimum/final_xi_t_g.txt' from data_dir, plots the rows with N = Na through
    plot_xi_g_nfixed, adds the SQL line, a vertical line at 1/Na and the cavity-removal
    value, and writes '<fig_dir>xi_dep_g2k' through save_figure.
    '''
    imagewidth=10
    _, axs = plt.subplots(1,1,figsize=[imagewidth*cm,imagewidth*cm/goldenratio], dpi=300)
    Na = 45 # selection of a single case
    bclim = 1/Na
    bad_cavity = True # scaling criterion is plotting against g^2/k or k_eff/k

    # From cavity removal simulation with Na=45
    cavremsq=1.446e-01

    # Vertical guide at x = 1/Na
    axs.vlines(bclim,0.1,1.001,color="#80d0d0",zorder=0.4,lw=1,ls="--")

    # SQL reference line
    axs.plot(np.linspace(0,1,3),np.ones(3),c='k',linestyle='--',linewidth=1)
    axs.annotate(r'SQL',(0.1,0.85))

    # Cavity-removal value, drawn up to x = 1/Na
    axs.plot(np.linspace(0,bclim,3),cavremsq*np.ones(3),ls=':', linewidth=1.5, c="purple")

    # One row per parameter set; ndmin=2 keeps a single-row file two-dimensional
    data_G = np.loadtxt(data_dir+'minimum/final_xi_t_g.txt', ndmin=2)
    plot_xi_g_nfixed(data_G,axs,Na,bad_cavity)

    axs.set_xlim([0.001,1])
    axs.set_ylim([0.1,1.1])

    # Proxy artists for the legend; the label follows Na instead of hard-coding it
    add_handle=[Line2D([0], [0], marker=marker_array_k[0], color='w', markerfacecolor='k', label=f'$N = {Na}$',markersize=5),
                Line2D([0], [0], color='purple', linewidth=1.5, linestyle =':', label = r'cavity removal')]

    axs.legend(handles = add_handle,loc="center left",frameon=False, handlelength=1)

    axs.tick_params(axis='both',which='both',direction="in")
    save_figure(f'{fig_dir}xi_dep_g2k')

plot_xi_dep_g2k()


# +
def plot_xi_g2k_comparison():
    '''
    Side-by-side comparison of the two x-axis choices for the optimal squeezing, saved as a figure.

    Left panel: squeezing against g^2/(kappa Delta); right panel: squeezing against k_eff/k.
    Both panels use the rows with N = Na of 'minimum/final_xi_t_g.txt' and show the same
    guide lines. The result is written to '<fig_dir>xi_dep_comparison' through save_figure.
    '''
    imagewidth=12
    _, ax = plt.subplots(1,2,figsize=[2*imagewidth*cm,imagewidth*cm/goldenratio], dpi=300)
    Na = 45 # selection of a single case
    bclim = 1/Na

    # From cavity removal simulation with Na=45
    cavremsq=1.446e-01

    # One row per parameter set; ndmin=2 keeps a single-row file two-dimensional
    data_G = np.loadtxt(data_dir+'minimum/final_xi_t_g.txt', ndmin=2)
    plot_xi_g_nfixed(data_G,ax[0],Na,True)  # g^2/k
    plot_xi_g_nfixed(data_G,ax[1],Na,False) # k_eff

    ax[0].set_xlim([0.001,1])
    ax[0].set_ylim([0.1,1.1])
    ax[1].set_xlim([0.001,10])
    ax[1].set_ylim([0.1,1.1])
    ax[1].annotate(r'SQL',(0.1,0.85))

    # Same guide lines on both panels: vertical line at 1/Na, SQL line, cavity-removal value
    for panel_ax in ax:
        panel_ax.vlines(bclim,0.1,1.001,color="#80d0d0",zorder=0.4,lw=1,ls="--")

        panel_ax.plot(np.linspace(0,10,3),np.ones(3),c='k',linestyle='--',linewidth=0.5)

        panel_ax.plot(np.linspace(0,bclim,3),cavremsq*np.ones(3),ls=':', linewidth=1, c="purple")
        panel_ax.tick_params(axis='both',which='both',direction="in")

    save_figure(f'{fig_dir}xi_dep_comparison')

plot_xi_g2k_comparison()


# -

# # Trajectories in conditional dynamics
#
# - Correlation between the squeezing parameter at the optimal time and the $J_z$ component at the same time, for different numbers of atoms in the ensemble
# - Comparison of analytical expressions and simulations in the cavity removal approximation
# - Evolution of $n$, $J_z$ spin component and $\xi^2$ over time

# ## Correlation between spin and squeezing

# +
def plot_correlation_spin_squeezing(data_dir):
    '''
    Correlation between spin and squeezing in cavity removal simulations

    For each atom number, reads 'ad_sq.txt' (time and squeezing columns) and 'hist.txt'
    from '<data_dir>traj_distribution/correlation/Na<N>/data/', plots the histogram
    squeezing column against |Jz column|, marks the average squeezing and 1/sqrt(N)
    with dotted guides, and overlays the analytic conditional squeezing evaluated at
    the time where the squeezing in 'ad_sq.txt' is minimal.
    The figure is written to '<fig_dir>correlation_spin_squeezing' through save_figure.
    '''
    e = 0.1
    k = 0.1
    g = 0.0247

    keff=k_eff_eps(g,k,e,DeltaOverW0)

    natoms = [400,200,100,50,20]
    color = ['b','xkcd:cerulean','xkcd:kelly green','orange','xkcd:red']

    hist_dir = data_dir+'traj_distribution/correlation/'

    imagewidth=10 # cm
    _, axs = plt.subplots(1,1,figsize=[imagewidth*cm,imagewidth*cm/goldenratio], dpi=300)

    # Columns of the histogram file used below
    iWz_topt = 0
    iJz_topt = 4

    for n, c in zip(natoms, color):
        J = n/2.0

        # Rescaled time at which the squeezing (column 1) is minimal
        sq_array = np.loadtxt(hist_dir+f'Na{n}/data/ad_sq.txt',usecols=[0,1])
        ind_t_min = np.argmin(sq_array[:,1])
        tau = keff*sq_array[ind_t_min,0]

        histogram_array = np.loadtxt(hist_dir+f'Na{n}/data/hist.txt')
        order=np.argsort(histogram_array[:,iJz_topt])
        axs.plot(np.abs(histogram_array[order,iJz_topt]),histogram_array[order,iWz_topt],label=f'$N={n}$',color=c,zorder=1,linewidth=2.8)

        # Dotted guides: average squeezing (horizontal) and 1/sqrt(n) (vertical)
        xi2m = np.average(histogram_array[:,iWz_topt])
        axs.hlines(xi2m,0,1.15/np.sqrt(n),color=c,zorder=0.5,linestyle=":", lw=1)
        axs.vlines(1/np.sqrt(n),1e-2,xi2m*1.15,color=c,zorder=0.4,linestyle=":",lw=1)

        # Analytic conditional squeezing over the same Jz range
        Jzarr=np.linspace(0,1.2*np.max(histogram_array[:,iJz_topt]),100)
        axs.plot(Jzarr, xiAnalyticC(tau,Jzarr**2/contrastAnalytic(tau),J),linestyle='--',   linewidth=0.8 , color=[0.9,0.9,0.9], zorder=1.5)

    axs.set_yscale('log')
    axs.set_xlim((0,0.4))
    axs.set_ylim((1e-2,.5))
    axs.set_xlabel(r'$\left|\langle \hat{J}_z(t = t_m)\rangle_c / J\right|$')
    axs.set_ylabel(r'$\xi^2(t = t_m)$')

    # List the legend entries in reverse plotting order (smallest N first)
    handles, labels = axs.get_legend_handles_labels()
    legend=axs.legend(handles[::-1],labels[::-1],loc="lower right",frameon=False,fontsize=8)

    legend.get_frame().set_alpha(None)
    axs.tick_params(axis='both',which='both',direction="in")
    save_figure(f'{fig_dir}correlation_spin_squeezing')

plot_correlation_spin_squeezing(data_dir)


# -

# ## Check analytic expressions

# +
def plot_traj(axins,thisdir,g,e,k,Na,Ntraj,ave,legend,showx,showy,panel):
    '''
    Compare single-trajectory moments with the analytic expressions on axins.

    For each of the Ntraj files '<thisdir>/dir_traj_temporary/cav_<i>_ad.npy' the
    normalised variances, covariances and contrast are drawn against keff*t; the
    analytic variances, contrast and squeezing, and the averaged squeezing taken
    from ave, are drawn on top. Axes are log-log with fixed limits and fixed
    text annotations; panel is written in the lower-left corner.
    showy is accepted for signature compatibility and not used.
    '''
    J = Na/2
    keff = k_eff_eps(g,k,e,DeltaOverW0)
    tau = keff*ave[i_t,:]

    # Rows of a trajectory array: mean values ...
    iJX, iJY, iJZ = 0, 1, 2
    # ... and second moments (lower triangular covariance matrix, real elements)
    iJX2, iJYX, iJY2, iJZX, iJZY, iJZ2 = 3, 4, 5, 6, 7, 8

    def var(j2,jm):
        return (j2-jm**2)/(J/2)

    def cov(c2,j1,j2):
        return (c2-j1*j2)/(J/2)

    def contrast(x,y,z):
        return (x**2+y**2+z**2)/J**2

    def trajectory_curves(traj):
        ''' (values, linewidth, colour) for every curve drawn per trajectory, in drawing order '''
        return (
            (var(traj[iJX2],traj[iJX]),                  1.5, [0.1,0,0.4]),
            (var(traj[iJY2],traj[iJY]),                  1.5, [0,0.5,0]),
            (var(traj[iJZ2],traj[iJZ]),                  1.5, [0.5,0,0]),
            (cov(traj[iJZX],traj[iJX],traj[iJZ]),        0.3, [0.5,0.5,0]),
            (cov(traj[iJZY],traj[iJY],traj[iJZ]),        0.3, [0.5,0.5,0.5]),
            (cov(traj[iJYX],traj[iJX],traj[iJY]),        0.3, [0.2,0.3,0]),
            (contrast(traj[iJX],traj[iJY],traj[iJZ]),    1.5, [0,0.5,0.5]),
        )

    dir_traj = thisdir+'/dir_traj_temporary/'
    for i_traj in range(Ntraj):
        traj = np.load(f'{dir_traj}cav_{i_traj}_ad.npy')
        for values, width, rgb in trajectory_curves(traj):
            axins.plot(tau,values,linewidth=width,color=rgb,zorder=2.0)

    # Analytic expressions (thin dashed lines)
    analytic_curves = (
        (varXAnalytic(tau,J),   [1,0,1]),
        (varYAnalytic(tau,J),   [0,1,0]),
        (varZAnalytic(tau,J),   [1,0,0]),
        (contrastAnalytic(tau), [0,1,1]),
    )
    for values, rgb in analytic_curves:
        axins.plot(tau,values,linestyle='--',linewidth=0.7,color=rgb,zorder=2.2)

    # Averaged squeezing from the simulation and its analytic counterpart
    axins.plot(tau,ave[i_ad_KU,:],linestyle='-',linewidth=1.5,color=[0.2,0.2,0.2],zorder=2.2)
    axins.plot(tau,xiAnalytic(tau,J),linestyle='--',linewidth=1.2,color=[0.8,0.8,0.8],zorder=5)

    axins.set_yscale("log")
    axins.set_xscale("log")
    if legend:
        axins.legend(loc="upper right", framealpha=1, edgecolor="white")

    if not showx:
        axins.tick_params(axis='x',labelbottom=False)
    else:
        axins.set_xlabel(r'$\tilde{\kappa} t$')

    axins.tick_params(axis='both',which='both',direction="in")

    axins.set_xlim((0.004,700*keff))
    axins.set_ylim((0.004,20.0))

    # Curve labels placed by hand, in axes-fraction coordinates
    annotations = (
        ((0.05,0.36), r'2cov$J_x J_z/J$'),
        ((0.05,0.79), r'$2\Delta^2 J_y/J$'),
        ((0.58,0.79), r'$2\Delta^2 J_x/J$'),
        ((0.81,0.23), r'$2\Delta^2 J_z/J$'),
        ((0.81,0.46), r'E$[\xi^2]$'),
        ((0.81,0.64), r'$\mathcal{C}$'),
        ((0.06,0.05), panel),
    )
    for position, text in annotations:
        axins.annotate(xy=position,text=text,xycoords="axes fraction",zorder=7)

def plot_correlation():
    '''
    Build the figure comparing analytic expressions with cavity-removal trajectories.

    Loads the averaged data of the 160-atom cavity-removal run and delegates the
    drawing to plot_traj. The figure is left open as the current pyplot figure;
    saving is done by the caller.
    '''
    # Simulation parameters of the data set shown
    e, k, g = 0.1, 0.1, 0.0247
    Na = 160
    Ntraj = 100

    thisdir = data_dir + 'traj_distribution/' + '/inbadcavity/cavrem_160/'
    ave = np.loadtxt(thisdir+'data/ad_sq.txt').transpose()

    imagewidth = 10 # cm
    figure_size = [imagewidth*cm,imagewidth*cm/goldenratio]
    _, ax = plt.subplots(1,1,figsize=figure_size, dpi=300)
    plot_traj(ax,thisdir,g,e,k,Na,Ntraj,ave,legend=False,showx=True,showy=True,panel="")


plot_correlation()
save_figure(f'{fig_dir}analyticalComparison') 


# -

# ## Trajectories in and out of the bad-cavity regime

# +
def plot_nph_traj(axins,thisdir,e,k,Na,Ntraj,ad_sq,legend,showx,showy,panel):
    '''
    Draw the photon number, normalised to nph_eps(k, e), as a function of time on axins.

    Single trajectories are read from '<thisdir>full_<Na>/dir_traj_temporary/<i>_ad.npy'
    (missing files are skipped); the mean and its error band come from the rows
    iPh / iPherr of ad_sq. Times are multiplied by DeltaOverW0.
    legend, showx and showy toggle the legend and the axis labels; panel is written
    in the lower-left corner.
    '''
    inPh  = 9  # row of the photon number in a trajectory array
    nph0 = nph_eps(k,e)
    times = DeltaOverW0*ad_sq[it]

    dir_traj = thisdir+f'full_{Na}/dir_traj_temporary/'
    for i_traj in range(Ntraj):
        try:
            traj = np.load(f'{dir_traj}{i_traj}_ad.npy')
        except IOError:
            # Best effort: trajectories that were not stored are simply not drawn
            continue
        axins.plot(times,traj[inPh]/nph0,linewidth=0.2, color=[0.7,0.7,0.7], zorder=2.0)

    axins.fill_between(times, (ad_sq[iPh]-ad_sq[iPherr])/nph0,(ad_sq[iPh]+ad_sq[iPherr])/nph0, color=[1.0,0.5,0.5],linewidth=0.05, linestyle='--', zorder=2.1)
    # This curve is the mean photon number: label it as such (it used to carry the J_z label)
    axins.plot(times, ad_sq[iPh]/nph0,  label =r'$\mathrm{E}[n]$',c='r',linestyle='--',linewidth=1.5, zorder=2.2)

    axins.hlines(1, 0, DeltaOverW0*ad_sq[it,-1], label = 'nph0', color = 'green', linestyle = '-.',linewidth=1.5, zorder=2.3)
    if legend:
        axins.legend(loc="upper right", framealpha=1, edgecolor="white")

    if showx:
        axins.set_xlabel(r'$t \Delta$')
    else:
        axins.tick_params(axis='x',labelbottom=False)
    if showy:
        axins.set_ylabel(r'$\left\langle\hat{n}\right\rangle^{}_{\!c}/n_0$')
    else:
        axins.tick_params(axis='y',labelleft=False)

    axins.tick_params(axis='both',which='both',direction="in")

    axins.set_xlim((0,1000*DeltaOverW0))
    axins.set_ylim((0,1.0))
    axins.annotate(xy=(0.06,0.05),text=panel,xycoords="axes fraction",zorder=7)

    
def plot_Jz_traj(axins,thisdir,Na,Ntraj,ad_sq,legend,showx,showy,panel):
    '''
    Draw the Jz spin component, normalised to J = Na/2, as a function of time on axins.

    Single trajectories are read from '<thisdir>full_<Na>/dir_traj_temporary/<i>_ad.npy'
    (files raising IOError are skipped); the mean and its error band come from the
    rows iJzi / iJzierr of ad_sq. Two horizontal lines mark +-1/sqrt(2J).
    Times are multiplied by DeltaOverW0. legend, showx and showy toggle the legend
    and the axis labels; panel is written in the lower-left corner.
    '''
    J = Na/2
    iJZ = 2  # row of Jz in a trajectory array
    times = DeltaOverW0*ad_sq[it]
    t_last = DeltaOverW0*ad_sq[it,-1]
    folder = thisdir+f'full_{Na}/dir_traj_temporary/'

    for index in range(Ntraj):
        try:
            single = np.load(f'{folder}{index}_ad'+'.npy')
            axins.plot(times,single[iJZ]/J,linewidth=0.2, color=[0.7,0.7,0.7], zorder=1.6)
        except IOError:
            continue

    # Mean value with its error band
    mean_jz = ad_sq[iJzi]/J
    err_jz = ad_sq[iJzierr]/J
    axins.fill_between(times, mean_jz-err_jz,mean_jz+err_jz, color=[1.0,0.5,0.5],linewidth=0.05, linestyle='--', zorder=1.7)
    axins.plot(times, mean_jz,  label =r'$\mathrm{E}[J_z]$',c='r',linestyle='--',linewidth=1.5, zorder=1.8)

    # Reference band; only the upper line carries a legend label
    band = 1/np.sqrt(2*J)
    axins.hlines(band, 0, t_last, label = 'std dev. in initial state', color = 'green', linestyle = '-.',linewidth=1.5, zorder=1.9)
    axins.hlines(-band, 0, t_last, color = 'green', linestyle = '-.',linewidth=1.5, zorder=1.9)
    if legend:
        axins.legend(loc="upper right", framealpha=1, edgecolor="white")

    if not showx:
        axins.tick_params(axis='x',labelbottom=False)
    else:
        axins.set_xlabel(r'$t \Delta$')
    if not showy:
        axins.tick_params(axis='y',labelleft=False)
    else:
        axins.set_ylabel(r'$\langle\hat{J}_z\rangle^{}_{\!c}/J$')

    axins.tick_params(axis='both',which='both',direction="in")
    axins.set_xlim((0,1000*DeltaOverW0))
    axins.set_ylim((-0.5,0.5))
    axins.annotate(xy=(0.06,0.05),text=panel,xycoords="axes fraction",zorder=6)
    
def plot_xi_trajectories_generic(ax,thisdir,g, e,k,Na,Ntraj,measure,legend,showy,panel):
    '''
    Draw the squeezing parameter as a function of time on ax and return the averaged data.

    Plotted on a log y axis, with times multiplied by DeltaOverW0:
    - single trajectories from '<thisdir>full_<Na>/squeezing_traj/squeezing_<i>.npy'
      (missing files are skipped);
    - the averaged squeezing and its error band from '<thisdir>full_<Na>/data/ad_sq.txt';
    - Mancini_squeezing, the cavity-removal average from
      '<thisdir>cavrem_<Na>/data/ad_sq.txt' and xiAnalytic, all three shifted in time by
      an offset equal to the time at which the averaged photon number is closest to 90%
      of its fitted stationary value.

    measure is accepted for signature compatibility and not used.
    Returns the array loaded from the full-simulation 'ad_sq.txt' (rows indexed by
    it, iXi, iXierr, iPh, iPherr, iJzi, iJzierr).
    '''
    ad_file = thisdir+f'full_{Na}/data/ad_sq.txt'
    ad_sq = np.loadtxt(ad_file, usecols=[i_t,i_ad_KU,i_ad_KU_err,i_ad_Photon,i_ad_Ph_err,i_ad_Jz,i_ad_Jz_err]).transpose()
    nocav_file = thisdir+f'cavrem_{Na}/data/ad_sq.txt'
    no_cav = np.loadtxt(nocav_file, unpack=True)

    ydown, yup = 0.05,1.2
    ixi=1

    xup = 1000

    def nphstat(t,n0):
        ''' Constant model used to fit the stationary photon number '''
        return n0

    keff = k_eff_eps(g,k,e,DeltaOverW0)
    J = Na/2

    # Fit a constant to the second half of the photon-number record
    tm = len(ad_sq[it])
    half = tm//2
    p0 = (nph_eps(k,e))
    parameters, _ = curve_fit(nphstat, ad_sq[it,half:tm], ad_sq[iPh,half:tm], p0, sigma=ad_sq[iPherr,half:tm], absolute_sigma=True)
    n0 = parameters[0]
    idx = find_nearest(ad_sq[iPh,:], n0*0.9)
    offset = ad_sq[it,idx] # OFFSET= when number of photons is 90% of stationary value

    dir_sq_traj = thisdir+f'full_{Na}/squeezing_traj/'
    for i_traj in range(Ntraj):
        try:
            traj_sq = np.load(f'{dir_sq_traj}squeezing_{i_traj}.npy')
        except IOError:
            # Best effort: trajectories that were not stored are simply not drawn
            continue
        # select only values which are under upper bound (to avoid "overflow" of figure)
        visible = np.array([ad_sq[it],traj_sq])
        visible = visible[:,(visible[ixi,:]<=yup+1)]
        visible = visible[:,(visible[0,:]<=xup+10)]

        ax.plot(DeltaOverW0*visible[it],visible[ixi],linewidth=0.2, color=[0.7,0.7,0.7], zorder=0.5)

    ax.hlines(1,DeltaOverW0*ad_sq[it,0],DeltaOverW0*ad_sq[it,-1],label='SQL',color='k', zorder=1,linestyle="--")
    ax.fill_between(DeltaOverW0*ad_sq[it], ad_sq[iXi]-ad_sq[iXierr], ad_sq[iXi]+ad_sq[iXierr], color=[1.0,0.5,0.5],linewidth=0.05, zorder=1.1)
    ax.plot(DeltaOverW0*ad_sq[it], ad_sq[iXi],  label =r'$\mathrm{E}[\xi^2(t)]$',c='r',linewidth=1.5,linestyle='--', zorder=1.2)
    ax.plot(DeltaOverW0*(ad_sq[it]+offset), Mancini_squeezing(ad_sq[it], Na, keff),label =r'$\xi^2_{F}$ (time offset)',c='b',linewidth=1.5,linestyle='-.', zorder=1.3)

    ax.plot(DeltaOverW0*(no_cav[i_t]+offset),no_cav[i_ad_KU],c='purple',linestyle=':',label=r'$\mathrm{E}[\xi^2(t)]$ Cavity removal (time offset)',linewidth=1.5,zorder=1.4)
    ttt=np.linspace(0,1000,200)
    ax.plot(DeltaOverW0*(ttt+offset), xiAnalytic(keff*ttt,J),linestyle='-',   linewidth=1.5 , color=[0.0,0.5,0.0], zorder=1.25)

    ax.set_yscale('log')
    ax.set_ylim((ydown,yup))
    ax.set_xlim((0,xup*DeltaOverW0))
    ax.set_xlabel(r'$t \Delta$', labelpad=1.5)
    if showy:
        ax.set_ylabel(r'$\xi^2$')
    else:
        ax.tick_params(axis='y',labelleft=False)

    if legend:
        # Reorder the four labelled artists for the legend
        handles, labels = ax.get_legend_handles_labels()
        order = [3,0,1,2]
        ax.legend([handles[idx] for idx in order],[labels[idx] for idx in order],loc="lower left",frameon=False)

    ax.tick_params(axis='both',which='both',direction="in")
    ax.annotate(xy=(0.06,0.05),text=panel,xycoords="axes fraction",zorder=9)

    return ad_sq


def plot_3_2_traj():
    '''
    Build the 3x2 figure of photon number, Jz and squeezing trajectories.

    Left column: data from 'traj_distribution/inbadcavity/' (e = k = 0.2);
    right column: data from 'traj_distribution/outbadcavity/' (e = k = 0.02).
    Rows from top to bottom: photon number, Jz, squeezing. The figure is left
    open as the current pyplot figure; saving is done by the caller.
    '''
    base_dir = data_dir + 'traj_distribution/'
    measure = 'homodyne_x'
    g = 0.0247
    Na = 45
    Ntraj = 400

    imagewidth = 24 # cm
    fig,axs = plt.subplots(3,2,figsize=[imagewidth*cm,imagewidth*cm/goldenratio], dpi=300, gridspec_kw={'height_ratios': [0.6,1,2.5]})
    fig.subplots_adjust(wspace=0.05, hspace=0.15)

    # One entry per column: (sub-directory, value of e and k, show y labels, panel tags top to bottom)
    columns = (
        ('inbadcavity/',  0.2,  True,  ("a)", "b)", "c)")),
        ('outbadcavity/', 0.02, False, ("d)", "e)", "f)")),
    )
    for col, (subdir, rate, with_ylabel, (tag_nph, tag_jz, tag_xi)) in enumerate(columns):
        e = k = rate
        thisdir = base_dir + subdir
        # The squeezing panel is drawn first because it loads the averaged data used by the other two
        averaged = plot_xi_trajectories_generic(axs[2,col],thisdir,g, e,k,Na,Ntraj,measure,legend=False,showy=with_ylabel,panel=tag_xi)
        plot_Jz_traj(axs[1,col],thisdir,Na,Ntraj,averaged,legend=False,showx=False,showy=with_ylabel,panel=tag_jz)
        plot_nph_traj(axs[0,col],thisdir,e,k,Na,Ntraj,averaged,legend=False,showx=False,showy=with_ylabel,panel=tag_nph)

    axs[2,0].annotate(r'SQL',(200,0.8))
    fig.align_labels()


plot_3_2_traj()
save_figure(f'{fig_dir}xi_trajectories_all_ph') 