import numpy as np
import matplotlib.pyplot as plt
import json
import matplotlib as mpl
from string import ascii_lowercase

# ---------------- system parameters ---------
# Light / lattice geometry
lambda0 = 1                  # atomic transition wavelength
k0 = 2 * np.pi / lambda0     # atomic transition wavenumber
d0 = 0.1                     # dimensionless lattice constant
d = d0 * lambda0             # lattice constant (d0 scaled by lambda0)
D0 = 1.6                     # dimensionless system size
D = D0 * d                   # system size (D0 lattice constants)

# Model / Hamiltonian options
hermitian = True             # if True, diagonalize the Hermitian part of the Hamiltonian
simplify = False             # if True, use the simplified model instead of the full model
Vquad0 = 3                   # rescaled confining-potential strength
Npart = 2                    # number of particles (some parts of the code only work for two)
Lmax = 6                     # Lmax-fold rotational symmetry
rtrunc = -1                  # interaction truncation radius (-1 means no truncation)
B0 = 12                      # rescaled magnetic field

# Emission / plotting
w0 = 1                       # waist parameter of the LG modes
Lgs = 2                      # emission is computed from this subspace
fsize = 16                   # font size used throughout the figure

# Global matplotlib font settings: serif text using the "CMU Serif" face,
# and the "cm" mathtext font set for LaTeX-style labels.
for _rc_key, _rc_value in (
    ("font.family", "serif"),
    ("font.serif", ["CMU Serif"]),
    ("mathtext.fontset", "cm"),
):
    plt.rcParams[_rc_key] = _rc_value

# Load the mode projections: column 0 holds the l values, column 1 is used
# below (as scale_minus) to rescale the "minus" emission intensities.
# NOTE(review): the file name is built with Npart=1 rather than the Npart
# parameter -- presumably single-particle projections; confirm.
proj_fname = ("modeprojections_OBC_d0={:.4f}_herm{}_simp{}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_rtrunc={:.4f}.txt"
              .format(d0, hermitian, simplify, D0, 1, B0, Vquad0, rtrunc))
proj = np.loadtxt(proj_fname)
lproj = proj[:, 0].astype(int)


# Load the output light intensities for the reference parameter set.
out_fname = ("output_OBC_d0={:.4f}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_Lgs{}.json"
             .format(d0, D0, Npart, B0, Vquad0, Lgs))
with open(out_fname) as f:
    data_dict = json.load(f)
sp = data_dict["single photon emission"]
ls = sp["ls"]
i0 = ls.index(0)
Nstate = 2  # how many states to plot (rescaled intensity plots)

# The intensities are divided element-wise by the projections, so both files
# must use the same list of l. Fail loudly here instead of with a NameError
# on scale_minus further down.
if ls != lproj.tolist():
    raise ValueError(
        "l-grid mismatch between {} (l={}) and {} (l={})".format(
            out_fname, ls, proj_fname, lproj.tolist()))
scale_minus = proj[:, 1]

# Figure layout: a short top row for panels (a)-(b) and a taller bottom row
# (height ratio 1:3) for the signature panels.
fig = plt.figure(figsize=(6, 7.3))
gs = mpl.gridspec.GridSpec(
    2, 1,
    figure=fig,
    height_ratios=[1, 3],
    left=0.14, right=0.98,
    top=0.97, bottom=0.07,
    hspace=0.4,
)
# Top row split into Nstate side-by-side panels.
gs1 = gs[0, 0].subgridspec(1, Nstate)
ax = ax1 = gs1.subplots()

################################################################
##  plot rescaled intensities ( (a), (b) )
################################################################

# For the first Nstate states, draw the "minus" single-photon intensities
# divided element-wise by the mode projections (scale_minus).
for istate, state in enumerate(sp["data"][:Nstate]):
    panel = ax[istate]
    rescaled = state["occ_minus"] / scale_minus
    panel.bar(ls, rescaled)
    panel.set_xlabel(r"$l$", fontsize=fsize)
    if istate == 0:
        # Only the leftmost panel carries a y-label.
        panel.set_ylabel(r"$\tilde{I}^{-}_l$", fontsize=fsize)

################################################################
##  orbital occupations  ( (a), (b) )
################################################################

# Map each l to the orbital index (1 + l) mod Lmax; Python's % with a
# positive divisor is non-negative, so negative l wrap around.
inds = [(1 + ell) % Lmax for ell in ls]

# Overlay the reordered orbital occupations as unfilled red bars.
occ = data_dict["occ"]
for istate, state in enumerate(occ["data"][:Nstate]):
    reordered = [state["occ"][orb] for orb in inds]
    ax[istate].bar(ls, reordered, fill=False, edgecolor="r")

# Shared limits, tags and tick styling for panels (a) and (b); the y tick
# labels of (b) are hidden since it uses the same y-range as (a).
for idx, (tag_x, tag) in enumerate([(-0.32, "(a)"), (-0.16, "(b)")]):
    panel = ax[idx]
    panel.set_xlim(-3.5, 3.5)
    panel.set_ylim(0, 2)
    if idx == 1:
        panel.set_yticklabels([])
    panel.text(tag_x, 0.95, tag, transform=panel.transAxes, fontsize=fsize)
    panel.tick_params(axis='both', which='major', labelsize=fsize)

###############################################################################
#  plot signatures ( (c) -(f) )
################################################################################

# Bottom row: four stacked panels (S1, S2/S2', S3, S4) against eigenstate index.
gs1 = gs[1, 0].subgridspec(4, 1)
ax2 = gs1.subplots()
ax = ax2

# One output file per (d0, Vquad0) pair; the line style tells them apart.
d0s = [0.1, 0.2, 0.3, 0.4]
Vquad0s = [3.0, 10, 10, 10]
styles = ["solid", "dashed", "dotted", "dashdot"]

# S2 and S2' span several decades (see the y-limits set later), hence log scale.
ax[1].set_yscale("log")

for id0, (d0_i, vquad0_i, style) in enumerate(zip(d0s, Vquad0s, styles)):
    fname = ("output_OBC_d0={:.4f}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_Lgs{}.json"
             .format(d0_i, D0, Npart, B0, vquad0_i, Lgs))
    with open(fname) as f:
        data_dict = json.load(f)

    # -------- signatures based on single-photon emission --------
    sp = data_dict["single photon emission"]
    ls = sp["ls"]
    i0 = ls.index(0)
    # S1/S2/S2'/S4 read up to two entries on each side of l = 0; without this
    # check a negative index would silently wrap to the end of the list.
    if i0 < 2 or i0 + 2 >= len(ls):
        raise ValueError(
            "{}: need two l entries on each side of l=0, got ls={}".format(fname, ls))

    S1, S2, S2p, S4 = [], [], [], []
    for state in sp["data"]:
        occ_m = state["occ_minus"]
        occ_p = state["occ_plus"]
        S1.append(occ_m[i0 - 1] / occ_m[i0 + 1])   # S1
        S2.append(occ_m[i0 - 2] / occ_m[i0 - 1])   # S2
        S2p.append(occ_m[i0 + 2] / occ_m[i0 + 1])  # S2'
        S4.append(occ_p[i0 + 1] / occ_p[i0])       # S4
    x_sp = range(len(sp["data"]))
    color = "C{}".format(id0)

    # NOTE: the shaded reference bands are drawn once per dataset, so their
    # translucency accumulates over the four iterations (kept as in the figure).
    ax[0].plot(x_sp, S1, linestyle=style, marker="o", label="$d_0={}$".format(d0_i))
    ax[0].axhspan(0.65, 1.05, facecolor="gray", alpha=0.1)
    # Only the first dataset contributes S2/S2' legend entries.
    ax[1].scatter(x_sp, S2, c=color, marker="+",
                  label=r"$S_2$" if id0 == 0 else "_nolegend_")
    ax[1].scatter(x_sp, S2p, c=color, marker="x",
                  label=r"$S_2'$" if id0 == 0 else "_nolegend_")
    ax[1].axhspan(10**(-7), 2*10**(-3), facecolor="gray", alpha=0.1)
    ax[3].plot(x_sp, S4, marker="o", linestyle=style)
    ax[3].axhspan(1, 6, facecolor="lightblue", alpha=0.1)

    # -------- signature S3 (two-photon emission phase) --------
    mp = data_dict["two-photon emission"]
    S3 = []
    for state in mp["data"]:
        aph = np.array(state["amplitude phase"])
        # Wrap into [-0.02, 2*pi - 0.02) so phases just below 0 stay near 0
        # instead of jumping to ~2*pi.
        aph = (aph + 0.02) % (2 * np.pi) - 0.02
        # presumably the phase of the second l-pair; TODO confirm against producer
        S3.append(aph[1])
    # x-range follows S3 itself: the two-photon state count need not equal
    # the single-photon one.
    ax[2].plot(range(len(S3)), S3, marker="o", linestyle=style)
    ax[2].axhspan(3, 3.3, facecolor="gray", alpha=0.1)

# --------- adjust plot details -----------

# Explicit y-ranges for panels (c), (f) and (d); (e) gets none here.
for idx, (y_lo, y_hi) in ((0, (0, 2)), (3, (0, 6)), (1, (10**-7, 10))):
    ax[idx].set_ylim(y_lo, y_hi)

ax[3].set_xlabel("Eigenstate index", fontsize=fsize, labelpad=-2)

ylabel_specs = (
    ("$S_1$", {}),
    ("$S_2, S_2'$", {"labelpad": -3}),
    ("$S_3$", {}),
    ("$S_4$", {}),
)
for panel, (ylabel, extra) in zip(ax, ylabel_specs):
    panel.set_ylabel(ylabel, fontsize=fsize, **extra)

# Sparse decade ticks on the log panel; 0 and pi on the phase panel.
ax[1].set_yticks([10**(-6), 10**(-3), 10**(0)])
ax[2].set_yticks([0, np.pi])
ax[2].set_yticklabels([0, r"$\pi$"], fontsize=fsize)

# Tick styling and panel tags (c)-(f); only the bottom panel keeps x tick labels.
for idx, panel in enumerate(ax):
    panel.tick_params(axis='both', which='major', labelsize=fsize)
    panel.text(-0.15, 0.9, "({})".format(ascii_lowercase[idx + 2]),
               transform=panel.transAxes, fontsize=fsize)
    if idx < 3:
        panel.set_xticklabels([])

ax[0].legend(ncol=4, bbox_to_anchor=[0, 1.5, 1.03, 0.2], fontsize=fsize,
             columnspacing=0.5, handletextpad=0.1, handlelength=1.9, loc=1)
ax[1].legend(ncol=2, loc="lower right", fontsize=fsize, columnspacing=0.5,
             handletextpad=0.1, handlelength=0.4, borderpad=0.1, borderaxespad=0.1)

plt.savefig("Figure7.pdf")
plt.show()
