import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.image as mpimg
import json
from string import ascii_lowercase

#------------------------- system parameters ------------------------
lambda0=1 # atomic transition wavelength
k0=2*np.pi/lambda0 # atomic transition wavenumber
hermitian=True # whether to diagonalize only the Hermitian part of the Hamiltonian (keep True, otherwise the orbitals are non-orthogonal)
D0=0.6 # system size
Vquad0=0 # rescaled confining potential strength (not useful for D0=0.6)
Npart=2 # maximum number of particles (note: some parts of the code work for two particles only)
Lmax=6 # Lmax-fold rotational symmetry
rtrunc=-1 # truncation radius of the interaction (-1 if no truncation)
B0=12 # rescaled magnetic field
# Waist parameter. It was previously assigned twice with the same value, once as
# "waist of the LG modes" and once as "w0 of the measurement Gaussian modes";
# a single definition serves both roles.
w0=1
omega=0.1 # Rabi frequency in units of gamma0
#drive="gaussian"
drive="uniform" # type of drive ("gaussian" or "uniform")
Ltarget=1 # single-particle angular momentum sector accessed by the drive
pol=0 # polarization (0 = minus, 1 = plus)
fsize=13 # font size for the plots

# write the polarization as a two-component vector: pol=0 -> [1,0], pol=1 -> [0,1]
if pol==0:
  pol2=[1,0]
elif pol==1:
  pol2=[0,1]
else:
  # previously pol2 was silently left undefined and the script failed later with a NameError
  raise ValueError("pol must be 0 or 1, got {!r}".format(pol))

# typeset every figure with a serif text font (CMU Serif) and Computer Modern mathtext
plt.rcParams.update({
  'font.family': 'serif',
  'font.serif': ['CMU Serif'],
  'mathtext.fontset': 'cm',
})


# part of the data filename that encodes the type of driving
if drive=="gaussian":
  # orbital index written into the filename: Ltarget-1 for pol=0, Ltarget+1 for pol=1
  l=Ltarget+(2*pol-1)
  drivestr="g_l={}_w={:.2f}".format(l,w0)
elif drive=="uniform":
  drivestr="u"
else:
  # previously drivestr was silently left undefined and the script failed later with a NameError
  raise ValueError('drive must be "gaussian" or "uniform", got {!r}'.format(drive))

# fig: four stacked panels (S_1 ... S_4) sharing the Delta_0 axis
fig, ax = plt.subplots(4, 1, sharex=True, figsize=(5, 6))

# fig2: 2x2 grid, filled row by row with |c_1|^2, |c_2|^2, the overlap and S_5
fig2 = plt.figure(figsize=(5, 5))
spec2 = gridspec.GridSpec(ncols=2, nrows=2, figure=fig2)
axnexc1, axnexc2, axov, axS5 = (
  fig2.add_subplot(spec2[row, col]) for row in (0, 1) for col in (0, 1)
)

# |c_2|^2 is divided by 10**exponent_nexc2 before plotting
exponent_nexc2 = -4
scale_nexc2 = 10 ** exponent_nexc2

def _complex_amplitude(entry):
  """Return the complex array stored as separate "real"/"imag" lists in *entry*."""
  return np.array(entry["real"])+1j*np.array(entry["imag"])


def _first_local_max(x, y):
  """Return x at the first strict interior local maximum of y, or None if there is none.

  Only interior samples 0 < i < len(y)-1 with y[i-1] < y[i] > y[i+1] are considered,
  so the first sample is never compared against the last one (no wrap-around).
  """
  y=np.asarray(y)
  peaks=np.flatnonzero((y[1:-1]>y[:-2])&(y[1:-1]>y[2:]))
  if peaks.size==0:
    return None
  return x[peaks[0]+1]


alpha=0.7 # transparency of the dashed single-particle energy lines
for id0, d0 in enumerate([0.1,0.2, 0.3, 0.4]):
  color="C{}".format(id0) # one color per d0, shared by all panels
  # read the simulation results for this d0
  fname="driving_ang_pol_"+drivestr+"_omega={:.4f}_pol{:.4f}_{:.4f}_Ltarget={}_d0={:.4f}_D0={:.4f}_B0={:.4f}_Vquad0={:.4f}.json".format(omega, pol2[0], pol2[1],Ltarget, d0,D0,B0,Vquad0)
  with open(fname) as f:
    data=json.load(f)
  # index 0 is used below for |c_1|^2 and the single-particle energy,
  # index 1 for |c_2|^2, the overlaps and the two-photon amplitudes
  results=data["subspace_results"]
  twopart_data=results[1]
  # extract data from the file
  d=d0*lambda0
  deltas=np.array(data["deltas"])
  deltas0=deltas*d**3*k0**3 # rescaled detuning Delta_0 (x axis of every panel)
  ndeltas=len(deltas)
  # overlaps of the first stored state of subspace 1 (axis labelled |<psi_0|psi_steady^(2)>|^2)
  overlaps=np.array(twopart_data["states"][0]["overlaps"])
  axov.plot(deltas0, overlaps)
  # S5 = |A_0|^2 / I^-_0, where I^-_0 sums the l=0, pol=[1,0] intensity over all subspaces
  I0tot=np.zeros(ndeltas)
  for subspace in results:
    for mode in subspace["single-photon intensity"]:
      if mode["l"]==0 and mode["pol"]==[1,0]:
        I0tot+=mode["intensity"]
  A0=_complex_amplitude(twopart_data["two-photon amplitudes"][0])
  A1=_complex_amplitude(twopart_data["two-photon amplitudes"][1])
  S5=np.abs(A0)**2/I0tot
  axS5.plot(deltas0, S5)
  # single-mode intensities I^(-)_l for -2 <= l <= 2. The row index relies on Python's
  # negative indexing: rows 0,1,2 hold l=0,1,2 and rows 4,3 hold l=-1,-2.
  # NOTE(review): unlike I0tot above, these rows are overwritten (not summed) across
  # subspaces, so only the last subspace's values survive -- confirm this is intended.
  modes_minus=np.zeros((5, ndeltas))
  for subspace in results:
    for mode in subspace["single-photon intensity"]:
      if -3<mode["l"]<3 and mode["pol"]==[1,0]:
        modes_minus[int(mode["l"]),:]=mode["intensity"]
  # S1 = I^(-)_{-1} / I^(-)_{1}
  ax[0].plot(deltas0, modes_minus[-1,:]/modes_minus[1,:])
  # S2: I^(-)_{-2}/I^(-)_{-1} (solid) and I^(-)_{2}/I^(-)_{1} (dashed)
  ax[1].plot(deltas0, modes_minus[-2,:]/modes_minus[-1,:], c=color, label="I^{(-)}_{-2}/I^{(-)}_{-1}")
  ax[1].plot(deltas0, modes_minus[2,:]/modes_minus[1,:], c=color, label="I^{(-)}_{2}/I^{(-)}_{1}", linestyle="--")
  # S3: relative phase of the first two two-photon amplitudes
  ax[2].plot(deltas0, np.angle(A1/A0))
  # S4 = I^(+)_1 / I^(+)_0 (same overwrite-across-subspaces caveat as modes_minus)
  modes_plus=np.zeros((2, ndeltas))
  for subspace in results:
    for mode in subspace["single-photon intensity"]:
      if -1<mode["l"]<2 and mode["pol"]==[0,1]:
        modes_plus[int(mode["l"]),:]=mode["intensity"]
  ax[3].plot(deltas0, modes_plus[1,:]/modes_plus[0,:],label="$d_0={}$".format(d0))
  # Mark the first local maximum of the overlap with a dotted vertical line.
  # Previously the search started at i=0 (so overlaps[-1] was compared), and a curve
  # with no maximum reused the previous d0's position or raised a NameError.
  lmax=_first_local_max(deltas0, overlaps)
  if lmax is not None:
    for axis in ax:
      axis.axvline(lmax, c=color, linestyle=":")
    axS5.axvline(lmax, c=color, linestyle=":")
  # plot |c_1|^2 and |c_2|^2 (the latter in units of scale_nexc2)
  nexc1=results[0]["nexc"]
  nexc2=results[1]["nexc"]
  axnexc2.plot(deltas0, np.array(nexc2)/scale_nexc2, label="$d_0={}$".format(d0), zorder=4)
  axnexc1.plot(deltas0, nexc1)
  # single-particle ground-state energy, rescaled like deltas0, as a dashed vertical line
  E0=results[0]["states"][0]["E"]*d**3*k0**3
  for axis in (axnexc1, axnexc2, axov):
    axis.axvline(E0, c=color, alpha=alpha, linestyle="--")

# ---- fig2 (saved as Figure9.pdf): labels, ranges, panel letters ----
axnexc1.set_ylabel(r"$|c_1|^2$", fontsize=fsize, labelpad=-2)
axnexc2.set_ylabel(r"$|c_2|^2$", fontsize=fsize)
axov.set_ylabel(r"$|\langle \psi_0 | \psi_\mathrm{steady}^{(2)} \rangle|^2$", fontsize=fsize, labelpad=1)
axS5.set_ylabel(r"$S_5$", fontsize=fsize, labelpad=-5)
for axis in (axnexc2, axnexc1, axov, axS5):
  axis.set_xlabel(r"$\Delta_0$", fontsize=fsize, labelpad=1)
  axis.tick_params(axis='both', which='major', labelsize=fsize)
axnexc1.set_ylim(-0.003,0.03 )
axnexc2.set_ylim(-0.5,5.5 )
# |c_2|^2 was divided by scale_nexc2 before plotting, so print the factor by hand
axnexc2.text(-0.05, 1.00,  r"$\times 10^{{{}}}$".format(exponent_nexc2)   ,  transform=axnexc2.transAxes, fontsize=fsize)
axnexc1.text(-0.4, 0.99, "(a)", transform=axnexc1.transAxes, fontsize=fsize)
axnexc2.text(-0.3, 0.99, "(b)", transform=axnexc2.transAxes, fontsize=fsize)
axov.text(-0.4, 1, "(c)", transform=axov.transAxes, fontsize=fsize)
axS5.text(-0.3, 1, "(d)", transform=axS5.transAxes, fontsize=fsize)
axnexc2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
axnexc2.yaxis.offsetText.set_fontsize(fsize)
axnexc2.set_xlim(-25,-5 )
axov.set_xlim(-36,5 )
axS5.set_xlim(-36,5 )

# ---- fig (saved as Figure11.pdf): signatures S_1 ... S_4 ----
ax[0].set_ylabel(r"$S_1$", fontsize=fsize)
ax[1].set_ylabel(r"$S_2$", fontsize=fsize, labelpad=-2)
ax[2].set_ylabel(r"$S_3$", fontsize=fsize, labelpad=-4)
ax[3].set_ylabel(r"$S_4$", fontsize=fsize)
ax[3].set_xlabel(r"$\Delta_0$", fontsize=fsize, labelpad=1)
ax[0].set_ylim(0,3)
ax[3].set_ylim(-0.1,4)
ax[1].set_yscale("log")
# shaded reference bands
ax[0].axhspan(0.75, 1.05, facecolor="gray", alpha=0.3)
ax[1].axhspan(0, 10**(-3), facecolor="gray", alpha=0.3)
ax[2].axhspan(-3, -3.3, facecolor="gray", alpha=0.3)
# the same S_3 band shifted by 2*pi, since np.angle wraps phases into (-pi, pi]
ax[2].axhspan(2*np.pi-3, 2*np.pi-3.3, facecolor="gray", alpha=0.3)
ax[3].axhspan(1, 4, facecolor="lightblue", alpha=0.3)
ax[2].set_yticks([-np.pi, 0, np.pi])
ax[2].set_yticklabels([r"$-\pi$", "0", r"$\pi$"])
for i, axis in enumerate(ax):
  # vertical line at the maximum of S5 for d_0=0.4 (drawn in that curve's color, C3)
  axis.axvline(-12.844,c="C3")
  axis.text(-0.16, 0.9, "("+ascii_lowercase[i]+")", transform=axis.transAxes, fontsize=fsize)
  axis.tick_params(axis='both', which='major', labelsize=fsize)

# Layout and saving go through the figure objects directly instead of relying on
# pyplot's global "current figure" (plt.figure(fig) + plt.subplots_adjust/savefig).
fig.subplots_adjust(left=0.14, right=0.99, bottom=0.14, top=0.98, hspace=0.2)
ax[3].legend(ncol=4, loc="center",bbox_to_anchor=[-0.0, -0.38,1,-0.4], fontsize=fsize, handlelength=1, handletextpad=0.2, columnspacing=0.5)
fig.savefig("Figure11.pdf")
fig2.subplots_adjust(left=0.14, right=0.99, bottom=0.17, top=0.95, wspace=0.42, hspace=0.35)
axnexc2.legend(ncol=4, loc="center",bbox_to_anchor=[-1.26, -1.84,2,0.2], fontsize=fsize, handlelength=1, handletextpad=0.2, columnspacing=0.5)
fig2.savefig("Figure9.pdf")
plt.show()
