import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import json
from string import ascii_lowercase
#------------------------- system parameters ------------------------
lambda0 = 1  # atomic transition wavelength
k0 = 2 * np.pi / lambda0  # atomic transition wavenumber
d0 = 0.3  # dimensionless lattice constant (lattice constant in units of lambda0)
d = d0 * lambda0  # lattice constant
D0 = 1.6  # system size in units of the lattice constant
D = D0 * d  # system size
Vquad0 = 10  # rescaled confining potential strength (not useful for D0=0.6)
Npart = 2  # maximum number of particles (note: some parts of the code work for two particles only)
Lmax = 6  # Lmax-fold rotational symmetry
rtrunc = -1  # truncation radius of interaction (-1 if no truncation)
B0 = 12  # rescaled magnetic field
# Waist parameter, shared by the LG driving modes and the Gaussian measurement
# modes (a single definition so the two roles cannot silently diverge).
w0 = 1
Lgs = 2  # will compute emission from this subspace
omega = 0.1  # Rabi frequency in units of gamma0
dip1 = 1  # dipole element of state -
dip2 = 1  # dipole element of state +
drive = "uniform"  # type of drive: "gaussian" or "uniform"
Ltarget = 1  # single-particle angular momentum sector that is accessed by driving
pol = 0  # polarization (0 = -, 1 = +)
fsize = 13  # font size for the plot

B = B0 / d**3 / k0**3  # non-rescaled magnetic field
Vquad = Vquad0 / d**5 / k0**3  # non-rescaled harmonic potential

# write the polarization as a two-component vector: [1, 0] for "-", [0, 1] for "+"
if pol == 0:
    pol2 = [1, 0]
elif pol == 1:
    pol2 = [0, 1]
else:
    # without this, pol2 would be undefined and the file-name construction would fail later
    raise ValueError(f"pol must be 0 or 1, got {pol!r}")

# set plot fonts: serif text (CMU Serif) with Computer Modern math
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['CMU Serif'],
    'mathtext.fontset': 'cm',
})
# 4 x 2 grid of panels, all sharing one x axis
fig, ax = plt.subplots(4, 2, sharex=True, figsize=(5, 6))

# the part of the file name that depends on the driving type
if drive == "gaussian":
    # orbital index Ltarget - 1 for pol=0 and Ltarget + 1 for pol=1
    # (presumably the LG mode index of the drive — confirm against the producer script)
    l = Ltarget + (2 * pol - 1)
    drivestr = "g_l={}_w={:.2f}".format(l, w0)
elif drive == "uniform":
    drivestr = "u"
else:
    # without this, drivestr would be undefined and the file-name construction would fail later
    raise ValueError(f'drive must be "gaussian" or "uniform", got {drive!r}')

# open the file with the driven-system results
fname = "driving_ang_pol_" + drivestr + "_omega={:.4f}_pol{:.4f}_{:.4f}_Ltarget={}_d0={:.4f}_D0={:.4f}_B0={:.4f}_Vquad0={:.4f}.json".format(
    omega, pol2[0], pol2[1], Ltarget, d0, D0, B0, Vquad0)
# the context manager closes the file even if the JSON is malformed
with open(fname, encoding="utf-8") as f:
    data = json.load(f)

# read data from the file
deltas = np.array(data["deltas"])
# rescale detunings the same way B0 relates to B (multiply by (d*k0)**3)
deltas0 = deltas * d**3 * k0**3
ndeltas = len(deltas)
# overlaps of the first state of subspace result 1 (the |c_2|^2 panel also reads
# index 1, so presumably this is the two-excitation subspace — confirm)
overlaps = np.array(data["subspace_results"][1]["states"][0]["overlaps"])
# plot overlaps, scaled so the axis reads in units of 10**exponent
exponent = -2
scale = 10**(exponent)
overlap_ax = ax[1, 0]
overlap_ax.plot(deltas0, np.array(overlaps) / scale)
# print the scale factor just above the top-left corner of the panel
overlap_ax.text(0.0, 1.005, r"$\times 10^{{{}}}$".format(exponent),
                transform=overlap_ax.transAxes, fontsize=fsize)
# compute and plot S5
# single-photon intensity into the l=0 mode with polarization [1, 0],
# summed over every subspace result
l0_intensities = (
    mode["intensity"]
    for subs in data["subspace_results"]
    for mode in subs["single-photon intensity"]
    if mode["l"] == 0 and mode["pol"] == [1, 0]
)
I0tot = np.zeros(len(deltas))
for intensity in l0_intensities:
    I0tot += intensity
# first two-photon amplitude of subspace result 1 (the name I00 suggests the
# (0, 0) mode pair — confirm against the data file)
twopart_data = data["subspace_results"][1]
lpair0 = twopart_data["two-photon amplitudes"][0]
amp00 = np.array(lpair0["real"]) + 1j * np.array(lpair0["imag"])
I00 = np.abs(amp00)**2
S5 = I00 / I0tot
ax[2, 0].plot(deltas0, S5)
# compute and plot S1, S2, S2'
# Row r of modes_minus holds the intensity of the mode with l = r and
# polarization [1, 0]. Negative l wrap around (l=-1 -> row 4, l=-2 -> row 3),
# so the rows are read back below with the same signed indices.
modes_minus = np.zeros((5, ndeltas))
for subs in data["subspace_results"]:
    for mode in subs["single-photon intensity"]:
        ell = mode["l"]
        if -3 < ell < 3 and mode["pol"] == [1, 0]:
            # NOTE(review): unlike I0tot for S5 (which sums with +=), intensities
            # of the same l from different subspace results overwrite each other
            # here, so only the last one contributes — confirm this is intended.
            modes_minus[int(ell), :] = mode["intensity"]
I_m1, I_p1 = modes_minus[-1, :], modes_minus[1, :]
I_m2, I_p2 = modes_minus[-2, :], modes_minus[2, :]
ax[2, 1].plot(deltas0, I_m1 / I_p1)                           # S1
ax[3, 0].plot(deltas0, I_m2 / I_m1, c="C0")                   # S2
ax[3, 0].plot(deltas0, I_p2 / I_p1, c="C0", linestyle="--")   # S2'
# compute and plot S3: phase of the ratio of the second to the first two-photon amplitude
lpair1 = twopart_data["two-photon amplitudes"][1]
A0, A1 = (np.array(p["real"]) + 1j * np.array(p["imag"]) for p in (lpair0, lpair1))
phase = np.angle(A1 / A0)
ax[3, 1].plot(deltas0, phase, label="$d_0={}$".format(d0))
# find the first local maximum of the overlap, plot it as a dotted vertical line
# Only interior points are tested, so no comparison wraps around to the
# opposite end of the array.
ov = np.asarray(overlaps)
is_peak = (ov[1:-1] > ov[:-2]) & (ov[1:-1] > ov[2:])
peak_idx = np.flatnonzero(is_peak) + 1  # +1: shift back to indices into ov
# rescaled detuning of the first peak (despite the name, a maximum), or None if there is none
lmin = deltas0[peak_idx[0]] if peak_idx.size else None
if lmin is None:
    print("no interior local maximum of the overlap found; skipping the marker line")
else:
    for panel in ax.flat:
        panel.axvline(lmin, c="C0", linestyle=":")
# load and plot |c_n|^2 for n = 1, 2, each scaled by its own 10**exponent
nexc1 = data["subspace_results"][0]["nexc"]
nexc2 = data["subspace_results"][1]["nexc"]
for panel, nexc, exponent, plot_kwargs in (
    (ax[0, 0], nexc1, -2, {}),
    (ax[0, 1], nexc2, -3, {"label": "$d_0={}$".format(d0)}),
):
    scale = 10**(exponent)
    panel.plot(deltas0, np.array(nexc) / scale, **plot_kwargs)
    # scale factor printed just above the top-left corner of the panel
    panel.text(0.0, 1.005, r"$\times 10^{{{}}}$".format(exponent),
               transform=panel.transAxes, fontsize=fsize)



#------ compute and plot rescaled two-photon intensities -----------
# mode projections from the file with Npart=1 (presumably single-particle results — confirm);
# column 0 holds l, column 1 is the per-mode factor the intensities are divided by
proj = np.loadtxt("modeprojections_OBC_d0={:.4f}_herm{}_simp{}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_rtrunc={:.4f}.txt".format(d0, True, False, D0, 1, B0, Vquad0, rtrunc))
lproj = proj[:, 0].astype(int)
lproj_list = list(lproj)  # built once; .index() raises ValueError if an l is missing
twopart_data = data["subspace_results"][1]
colors = mpl.colormaps["inferno"].resampled(5).colors
exponent = -3
scale = 10**(exponent)
coeffs = []
# only the first four mode pairs are rescaled and plotted
for ipair, lpair in enumerate(twopart_data["two-photon amplitudes"][:4]):
    l1, l2 = lpair["lpair"]
    twopart_amp = np.array(lpair["real"]) + 1j * np.array(lpair["imag"])
    ind1 = lproj_list.index(l1)
    ind2 = lproj_list.index(l2)
    coeff = np.abs(twopart_amp)**2 / proj[ind1, 1] / proj[ind2, 1]
    if l1 == l2:
        # presumably a symmetry factor for two photons in the same mode — confirm
        coeff /= 2
    coeffs.append(coeff)
    ax[1, 1].plot(deltas0, coeff / scale, label=r"${}, {}$".format(l1, l2), c=colors[ipair])
ax[1, 1].legend(loc="upper right", fontsize=fsize, labelspacing=0, handlelength=0.5, handletextpad=0.2, borderaxespad=0.)
ax[1, 1].text(0.0, 1.005, r"$\times 10^{{{}}}$".format(exponent), transform=ax[1, 1].transAxes, fontsize=fsize)

# fix plot details
# (row, col) -> (y-label, labelpad)
ylabels = {
    (0, 1): (r"$|c_2|^2$", 1),
    (0, 0): (r"$|c_1|^2$", 1),
    (1, 1): (r"$\tilde{I}_{l_1,l_2}$", -1),
    (1, 0): (r"$|\langle \psi_0 | \psi_\mathrm{steady}^{(2)} \rangle|^2$", 1),
    (2, 0): (r"$S_5$", 1),
    (2, 1): (r"$S_1$", 1),
    (3, 0): (r"$S_2, S_2'$", -1),
    (3, 1): (r"$S_3$", -8),
}
for (row, col), (text, pad) in ylabels.items():
    ax[row, col].set_ylabel(text, fontsize=fsize, labelpad=pad)
# (row, col) -> (ymin, ymax)
ylims = {
    (0, 0): (-0.8, 8),
    (0, 1): (-0.3, 3),
    (2, 0): (-0.04, 0.4),
    (1, 1): (-0.3, 3),
    (2, 1): (0, 1.5),
}
for (row, col), (lo, hi) in ylims.items():
    ax[row, col].set_ylim(lo, hi)
# x is shared by all panels, so this sets the detuning range everywhere
ax[0, 0].set_xlim(-29, 0)
for col in range(2):
    ax[3, col].set_xlabel(r"$\Delta_0$", fontsize=fsize)
ax[3, 0].set_yscale("log")
# gray shaded bands
ax[2, 1].axhspan(0.65, 1.05, facecolor="gray", alpha=0.2)
ax[3, 0].axhspan(0, 2*10**(-3), facecolor="gray", alpha=0.2)
for lo, hi in ((-3, -3.3), (2*np.pi-3, 2*np.pi-3.3)):
    ax[3, 1].axhspan(lo, hi, facecolor="gray", alpha=0.2)
# panel letters (a)..(h) in row-major order, plus tick and offset-text font sizes
txs = [-0.39, -0.33]  # letter x-position for the left / right column
for k, panel in enumerate(ax.flat):
    panel.text(txs[k % 2], 1.03, "(" + ascii_lowercase[k] + ")", transform=panel.transAxes, fontsize=fsize)
    panel.tick_params(axis='both', which='major', labelsize=fsize)
    panel.yaxis.offsetText.set_fontsize(fsize)
ax[0, 0].ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
plt.subplots_adjust(left=0.14, right=0.98, bottom=0.08, top=0.97, wspace=0.36, hspace=0.2)
plt.savefig("Figure18.pdf")
plt.show()
