import numpy as np
import matplotlib.pyplot as plt
import json
import matplotlib as mpl
from string import ascii_lowercase

# ---- model parameters -------------------------------------------------------
# D0, Npart, B0, Vquad0 and Lgs are interpolated into the names of the data
# files read further down; their physical meaning is not visible in this file.
lambda0 = 1
k0 = 2 * np.pi / lambda0  # presumably the wavenumber for lambda0 - TODO confirm
D0 = 0.6
Vquad0 = 0.0  # original inline note: "1.0 #000" (presumably earlier trial values)
Lmax = 6
Npart = 2
Lgs = 2
B0 = 12

# ---- figure typography ------------------------------------------------------
fsize = 16  # font size passed to every label, legend and tick call below

plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['CMU Serif'],
    'mathtext.fontset': 'cm',
})


####################################################################################################
# Signatures as a function of eigenstate index (first column)
####################################################################################################


# Left column of the figure: the signatures S1, S2/S2', S3 and S4 of every
# stored state, one curve (or marker colour) per value of d0.
d0s=[0.1, 0.2, 0.3]
fig, ax = plt.subplots(5,2, figsize=(6,6))
plt.subplots_adjust(top=0.98, bottom=0.1,left=0.125, right=0.98, wspace=0.15)
styles=["solid", "dashed", "dotted", "dashdot", (0, (3, 5, 1, 5, 1, 5))] # one line style per d0
# The S2/S2' panel is logarithmic; this does not depend on d0, so set it once.
ax[2,0].set_yscale("log")
for id0, d0 in enumerate(d0s):
  # Load the data file for this d0. The context manager closes the file even
  # when json.load() raises.
  fname="output_OBC_d0={:.4f}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_Lgs{}.json".format(d0, D0,Npart,B0,Vquad0, Lgs)
  with open(fname, encoding="utf-8") as f:
    data_dict=json.load(f)

  # ----------- signatures based on single-photon emission -----------
  sp=data_dict["single photon emission"]
  i0=sp["ls"].index(0)  # position of the entry 0 in "ls"; all ratios are taken relative to it
  if i0<2:
    # i0-2 would become a negative index and silently read from the END of the list
    raise ValueError("the entry 0 of 'ls' must be preceded by at least two entries in {}".format(fname))
  states=sp["data"]
  A=[s["occ_minus"][i0-1]/s["occ_minus"][i0+1] for s in states] #S1
  B=[s["occ_minus"][i0-2]/s["occ_minus"][i0-1] for s in states] #S2
  C=[s["occ_minus"][i0+2]/s["occ_minus"][i0+1] for s in states] #S2'
  E=[s["occ_plus"][i0+1]/s["occ_plus"][i0] for s in states]     #S4
  idx=range(len(states))  # eigenstate index used as the x coordinate

  ax[1,0].plot(idx, A,linestyle=styles[id0],marker="o", label="$d_0={}$".format(d0))
  # plot S2, S2'; add labels only once so that the legend distinguishes them
  # by marker shape and not by colour
  colour="C{}".format(id0)
  label_B={"label": r"$S_2$"} if id0==0 else {}
  label_C={"label": r"$S_2'$"} if id0==0 else {}
  ax[2,0].scatter(idx, B, c=colour, marker="+", **label_B)
  ax[2,0].scatter(idx, C, c=colour, marker="x", **label_C)

  # ----------- signature S3 from two-photon emission -----------
  mp=data_dict["two-photon emission"]
  D=[]
  for state in mp["data"]:
    aph=np.array(state["amplitude phase"])
    # wrap the phases into [-0.02, 2*pi-0.02) so that values marginally below
    # zero stay next to 0 instead of jumping to just under 2*pi
    aph=(aph+0.02) % (2*np.pi)-0.02
    D.append(aph[1])  # second stored phase; presumably arg(A_{-1,1})-arg(A_{0,0}) - TODO confirm
  # S3 is plotted against its own number of states, which need not equal the
  # number of single-photon entries
  ax[3,0].plot(range(len(D)), D, marker="o", linestyle=styles[id0])
  ax[4,0].plot(idx, E, marker="o", linestyle=styles[id0])


# ------- static styling of the signature panels: shaded bands, limits, legends, labels -------

# Shaded horizontal bands as (ymin, ymax, colour), keyed by panel row.
# Every band is drawn in both columns.
band_by_row={
  1: (0.65, 1.05, "gray"),
  2: (10**(-7), 2*10**(-3), "gray"),
  3: (3, 3.3, "gray"),
  4: (1, 6, "lightblue"),
}
for col in (0, 1):
  for row, (band_lo, band_hi, band_colour) in band_by_row.items():
    ax[row,col].axhspan(band_lo, band_hi, facecolor=band_colour, alpha=0.3)

# y-limits keyed by (row, column); only the S3 row uses different limits in the two columns.
ylim_by_panel={
  (1,0): (0,2),
  (1,1): (0,2),
  (2,0): (10**(-5),10**(0)),
  (2,1): (10**(-5),10**(0)),
  (3,0): (-0.3,3.5),
  (3,1): (-0.1,3.4),
  (4,0): (0,6),
  (4,1): (0,6),
}
for (row, col), (y_lo, y_hi) in ylim_by_panel.items():
  ax[row,col].set_ylim(y_lo, y_hi)

# Legends: the d0 legend of the S1 panel is anchored above that panel (into the
# row of the switched-off top-left axes); the S2/S2' legend is kept compact.
ax[1,0].legend(
  ncol=1,
  bbox_to_anchor=[0.1,1.51,0.5,0.37],
  loc="center",
  fontsize=fsize,
)
ax[2,0].legend(
  ncol=2,
  loc=1,
  labelspacing=0.1,
  handletextpad=0.2,
  handlelength=0.2,
  columnspacing=0.2,
  fontsize=fsize,
  borderpad=0.1,
)

# x labels go on the bottom row only
ax[4,0].set_xlabel("Eigenstate index", fontsize=fsize)
ax[4,1].set_xlabel(r"$d_0$", fontsize=fsize)

# y labels of the left column. More explicit label texts, kept from the
# original as commented-out code (written for a one-dimensional `ax`):
#   S1:      "$I_{-1}/I_{1}$"
#   S2, S2': "$I_{-2}/I_{-1}, I_{2}/I_{1}$"
#   S3:      "$arg(A_{-1,1})-arg(A_{0,0})$"
#   S4:      "$I^{(+)}_{1}/I^{(+)}_{0}$"
for row, label_text, label_extra in (
  (1, "$S_1$", {}),
  (2, "$S_2, S_2'$", {"labelpad": -5}),
  (3, "$S_3$", {}),
  (4, "$S_4$", {}),
):
  ax[row,0].set_ylabel(label_text, fontsize=fsize, **label_extra)
ax[0,1].set_ylabel(r"$O(|\Psi_0 \rangle)$", fontsize=fsize)

# enlarge the major tick labels on the panels listed here
for row, col in ((0,1), (1,0), (2,0), (3,0), (4,0), (4,1)):
  ax[row,col].tick_params(axis='both', which='major', labelsize=fsize)



#######################################################################################
## Signatures as a function of d0 (second column)
########################################################################################



# Right column of the figure: the signatures of the first stored state
# (index 0) and an overlap value, as a function of d0.
d0s=[0.1, 0.2, 0.3, 0.4, 0.5]
#d0s=[0.3]
A=[]  # S1
B=[]  # S2
C=[]  # S2'
D=[]  # S3
E=[]  # S4
overlaps=[]
for d0 in d0s:
  # The context manager closes the file even when json.load() raises.
  fname="output_OBC_d0={:.4f}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_Lgs{}.json".format(d0, D0,Npart,B0,Vquad0, Lgs)
  with open(fname, encoding="utf-8") as f:
    data_dict=json.load(f)
  # ----------- get signatures based on single-photon emissions ---------
  sp=data_dict["single photon emission"]
  i0=sp["ls"].index(0)  # position of the entry 0 in "ls"; all ratios are taken relative to it
  if i0<2:
    # i0-2 would become a negative index and silently read from the END of the list
    raise ValueError("the entry 0 of 'ls' must be preceded by at least two entries in {}".format(fname))
  state=sp["data"][0]
  occ_minus=state["occ_minus"]
  occ_plus=state["occ_plus"]
  A.append(occ_minus[i0-1]/occ_minus[i0+1]) #S1
  B.append(occ_minus[i0-2]/occ_minus[i0-1]) #S2
  C.append(occ_minus[i0+2]/occ_minus[i0+1]) #S2'
  E.append(occ_plus[i0+1]/occ_plus[i0])     #S4
  # ----------- get S3 ---------
  state=data_dict["two-photon emission"]["data"][0]
  aph=np.array(state["amplitude phase"])
  # wrap the phases into [-0.02, 2*pi-0.02) so that values marginally below
  # zero stay next to 0 instead of jumping to just under 2*pi
  aph=(aph+0.02) % (2*np.pi)-0.02
  D.append(aph[1])
  # ----------- get the overlap ---------
  # The parameters in the file name come from the same constants as the JSON
  # file name above, so both inputs always describe the same parameter set.
  # rtrunc and mmax have no corresponding constant in this script and stay literal.
  data=np.loadtxt("spectrum_OBC_d0={:.4f}_hermTrue_simpFalse_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_rtrunc=-1.0000_mmax5.txt".format(d0, D0, Npart, B0, Vquad0))
  # NOTE(review): row 20 / column 2 is hard-coded; presumably the row of the
  # state of interest and its overlap column - confirm against the file layout
  overlaps.append(data[20, 2])

# ------ plot signatures and overlap ------ 
ax[0,1].scatter(d0s, overlaps, c="k", marker="o")
ax[1,1].scatter(d0s, A, c="k", marker="o")
ax[2,1].scatter(d0s, B, c="k", marker="+")
ax[2,1].scatter(d0s, C, c="k", marker="x")
ax[2,1].set_yscale("log")
ax[3,1].scatter(d0s, D, c="k", marker="o")
ax[4,1].scatter(d0s, E, c="k", marker="o")

# ------- final cosmetics: ticks, tick labels, panel letters; then save and show -------

# the signature panels of the right column share their rows with the left
# column, so their y tick labels are blanked
for panel in ax[1:,1]:
  panel.set_yticklabels([])

# S3 row: ticks at 0 and pi in both columns, labelled in the left column only
for panel in ax[3,:]:
  panel.set_yticks([0, np.pi])
ax[3,0].set_yticklabels([0, r"$\pi$"], fontsize=fsize)

# fixed x ticks for every panel of the left column
for panel in ax[:,0]:
  panel.set_xticks([0,2,4,6,8,10])

# only the bottom row keeps its x tick labels
for panel in ax[:4,:].flat:
  panel.set_xticklabels([])

# panel letters: (a)-(d) down the left column, (f)-(i) down the right column,
# and (e) for the top-right panel
for row in range(1,5):
  left_panel=ax[row,0]
  right_panel=ax[row,1]
  left_panel.text(-0.25, 0.85, "({})".format(ascii_lowercase[row-1]), transform=left_panel.transAxes, fontsize=fsize)
  right_panel.text(-0.14, 0.85, "({})".format(ascii_lowercase[row+4]), transform=right_panel.transAxes, fontsize=fsize)
top_right=ax[0,1]
top_right.text(-0.45, 0.85, "(e)", transform=top_right.transAxes, fontsize=fsize)

# the top-left axes is not drawn (the d0 legend is anchored in that area)
ax[0,0].set_axis_off()
plt.savefig("Figure6.pdf")

plt.show()
