import inspect
import json
import math
import os
import sys
#from string import ascii_lowercase

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.sparse.linalg  # make sure sp.sparse.linalg.* is loaded
from matplotlib.cm import ScalarMappable


# import ED functions
# The directory containing this script is located via the current frame, and
# its parent directory is put at the front of sys.path before the two helper
# modules are imported (presumably they live in that parent directory).
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 
import NonHermitianAngularMomentumED as ed
import OrdinaryBosons as ob

# print arrays on one (very long) line and without scientific notation for small numbers
np.set_printoptions(suppress=True, linewidth=100000)


#------------------------- system parameters ------------------------
# All values below are module-level settings read by the rest of the script.
lambda0=1 # atomic transition wavelength
k0=2*np.pi/lambda0 # atomic transition wavenumber
d0=0.2  # dimensionless lattice constant (in units of lambda0)
d=d0*lambda0 # lattice constant
D0=1.6 # system size (in units of the lattice constant d)
D=D0*d # system size
Vquad0=10 # rescaled confining potential strength (not useful for D0=0.6)
Npart=2 # maximum number of particles (note: some parts of the code work for two particles only)
Lmax=6 # Lmax-fold rotational symmetry
rtrunc=-1 # truncation radius of interaction (-1 if no truncation)
B0=12 # rescaled magnetic field
w0=1 # waist parameter of the LG modes (NOTE: assigned again below with the same value)
Lgs=2 # will compute emission from this subspace (NOTE: overwritten later in the orbital-basis loop)
omega=0.1 # Rabi frequency in the units of gamma0
dip1=1 # dipole element of state -
dip2=1 # dipole element of state +
#drive="gaussian"
drive="uniform" # type of drive (gaussian or uniform)
Ltarget=1 # single-particle angular momentum sector that is accessed by driving
pol=0 # polarization (0=-, 1=+)
w0=1 # w0 of the measurement Gaussian modes (same variable as the LG waist above)

B=B0/d**3/k0**3 # non-rescaled magnetic field
Vquad=Vquad0/d**5/k0**3 # non-rescaled harmonic potential

################################################################################
## prepare the variables needed for driven system calculation
###############################################################################

# set polarization as a two-component vector: index 0 is the "-" component,
# index 1 is the "+" component
if pol == 0:
  pol2 = [1, 0]
elif pol == 1:
  pol2 = [0, 1]
else:
  # fail here with a clear message instead of a NameError on pol2 further down
  raise ValueError("pol must be 0 (-) or 1 (+), got {!r}".format(pol))


# lattice vectors (two vectors of length d, mirror images of each other about the x axis)
a_x = d*np.sqrt(3)/2.
a_y = d*0.5
latvec = [[a_x, -a_y], [a_x, a_y]]
# positions of the two atoms in the unit cell, in the basis of the lattice vectors
orb = [[1./3., 1./3.], [2./3., 2./3.]]


# Site positions: ed.GetLatticePointsSegment supplies the points of one segment,
# and ed.RotateAll is called for i = 0 .. Lmax-1 to produce the remaining copies.
# Every copy is added to a scatter plot and to the global coordinate lists.
xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)
Nseg = len(xs)
Nsite = Nseg*Lmax
xall, yall = [], []
for i in range(Lmax):
  xres, yres = ed.RotateAll(xs, ys, i, Lmax)
  plt.scatter(xres, yres)
  xall += list(xres)
  yall += list(yres)
pos = np.array([xall, yall]).T  # one row (x, y) per site
Nsite = pos.shape[0]  # number of sites actually generated
Ns = 2*Nsite  # two polarization states per site


# functions to evaluate the driving field and measurement modes. The evaluated matrix has shape Nsite sites x 2 polarizations.
def UniformFieldAll(pol, pos):
  """Return a position-independent field of unit strength.

  pol -- two-component polarization vector copied into every row
  pos -- array with one row per site; only its number of rows is used

  Returns a real array of shape (number of sites, 2).
  """
  field = np.zeros((pos.shape[0], 2))
  for row in field:
    row[:] = pol  # each row is a view, so this writes into `field`
  return field

def GaussianModeAll(w0, l, pol, pos):
  """Evaluate the normalized Gaussian mode on every site.

  w0  -- waist parameter passed on to GaussianMode
  l   -- angular momentum index passed on to GaussianMode
  pol -- two-component polarization vector multiplying the scalar mode amplitude
  pos -- array with one (x, y) row per site

  Returns a complex array of shape (number of sites, 2).
  """
  polvec = np.array(pol)
  field = np.zeros((pos.shape[0], 2), dtype=complex)
  for isite, point in enumerate(pos):
    field[isite, :] = GaussianMode(w0, l, point)*polvec
  return field


def GaussianMode(w0, l, point):
  """Evaluate a Gaussian mode at a single point in space.

  w0    -- waist parameter
  l     -- integer angular momentum index; enters the radial factor as |l|
           and the azimuthal phase as exp(-1j*l*phi)
  point -- pair (x, y)

  Returns the complex mode amplitude at `point`.
  """
  x, y = point
  rho = np.sqrt(x**2 + y**2)
  # math.factorial instead of np.math.factorial: `np.math` was only an alias of
  # the standard-library math module and no longer exists in NumPy 2.0
  E0 = np.sqrt(2/(math.factorial(abs(l))*np.pi))/w0
  phi = np.arctan2(y, x)
  return E0*(rho*np.sqrt(2)/w0)**np.abs(l)*np.exp(-rho**2/w0**2)*np.exp(-1j*l*phi)

# evaluate the driving field pattern on all sites (shape: Nsite x 2 polarizations)
if drive == "uniform":
  field = UniformFieldAll(pol2, pos)
elif drive == "gaussian":
  # mode index of the Gaussian drive: Ltarget - 1 for pol = 0, Ltarget + 1 for pol = 1
  ldrive = Ltarget + (2*pol - 1)
  field = GaussianModeAll(w0, ldrive, pol2, pos)
else:
  # fail here with a clear message instead of a NameError on `field` below
  raise ValueError('drive must be "uniform" or "gaussian", got {!r}'.format(drive))
field *= omega  # field pattern * field strength (Rabi frequency)


# site bases, one per particle number n = 0 .. Npart
# (the loop variables `n` and `basis` stay defined at module level after the loop)
bases = []
for n in range(Npart + 1):
  print("Generating basis for Npart=", n)
  basis = ed.GenerateBasis(n, Nsite)
  bases += [basis]
nconfs = [len(b) for b in bases]  # number of configurations in each sector

# angular momentum bases built on top of the site bases, again for n = 0 .. Npart
mombases = []
for n in range(Npart + 1):
  print("Generating angular momentum basis for Npart=", n) 
  mombases += [ed.FindAngularMomentumEigenstates(bases[n], n, Lmax)]


# generate non-Hermitian Hamiltonians in gauge-transformed bases (with angular momentum conservation)
# only the subspace relevant for driving (L = Ltarget * n % Lmax) is included in the calculation
interactions3 = ed.RealDipolarInteractionFiniteCircularFullTransform(
  pos, B, Vquad, k0, rtrunc, Lmax, False)  # last argument: hermitian=False
H0s_angular = [[]]  # index 0 is an empty placeholder so that H0s_angular[n] belongs to n particles
for n in range(1, Npart + 1):
  print("Generating H0 for Npart=", n)
  Ldriven = Ltarget*n % Lmax
  H0s_angular.append(
    ed.GenerateCRSHamiltonianWithAngularMomentum(Ldriven, interactions3, bases[n], Lmax, mombases[n]))


################################################################################
## prepare the variables for measurements during/after the calculation
###############################################################################

#-------------------------------------------------------------------------------------------------
# generate and diagonalize the Hermitian Hamiltonians in gauge-transformed basis (with angular momentum conservation)
# for the purpose of calculating overlaps of steady state with eigenstates
#-------------------------------------------------------------------------------------------------
# Hermitian reference problem: for every particle number n = 1 .. Npart the
# Hamiltonian of the driven sector L = n*Ltarget % Lmax is diagonalized.
# Entry n-1 of each list below belongs to n particles.
interactions2 = ed.RealDipolarInteractionFiniteCircularFullTransform(
  pos, B, Vquad, k0, rtrunc, Lmax, True)  # last argument: hermitian=True
evecssite_all = []   # eigenvectors in the site basis after ed.OriginalPhases
evecssite_all2 = []  # eigenvectors in the site basis as returned by ed.ConvertEigenstates
evecs_all = []       # eigenvectors in the angular momentum basis
evals_all = []       # eigenvalues, sorted by real part
nevals_all = []      # numbers of eigenvalues per momentum sector
nconf_max = 100      # sectors smaller than this are diagonalized densely
neval = 10           # number of eigenpairs requested from the sparse solver
for n in range(1, Npart + 1):
  print("Generating reference Hamiltonians for Npart=", n)
  L = n*Ltarget % Lmax
  Hherm = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions2, bases[n], Lmax, mombases[n])
  nconf_L = len(mombases[n][0][L])
  if nconf_L < nconf_max:
    # NOTE(review): np.linalg.eig does not exploit Hermiticity; if Hherm is
    # exactly Hermitian, np.linalg.eigh would give real eigenvalues and
    # orthonormal eigenvectors - confirm before switching.
    evals_L, evecs_L = np.linalg.eig(np.array(Hherm.todense()))
  else:
    # eigsh only accepts which in {"LM", "SM", "LA", "SA", "BE"}; the former
    # "SR" (an option of eigs) raised a ValueError. "SA" selects the
    # algebraically smallest eigenvalues.
    evals_L, evecs_L = sp.sparse.linalg.eigsh(
      Hherm, k=min(nconf_L - 1, neval), return_eigenvectors=True, which="SA")
  nevals_all.append(len(evals_L))
  order = np.argsort(np.real(evals_L))
  evals_L = evals_L[order]
  evecs_L = evecs_L[:, order]
  evals_all.append(evals_L)
  evecs_all.append(evecs_L)
  evecsite = ed.ConvertEigenstates(evecs_L, bases[n], mombases[n], L, Lmax)
  evecssite_all2.append(evecsite)
  evecssite_all.append(ed.OriginalPhases(evecsite, bases[n], Lmax, Nsite))
#---------------------------------------------------------------
# generate the mode occupation operators
#---------------------------------------------------------------
# Mode occupation operators. outmode_hams[ipol][ioutmode][n-1] is the operator
# for polarization ipol, Gaussian mode l0 = outmode_l0s[ioutmode] and n particles.
Nsector = Nsite//Lmax  # number of sites per rotational sector (also read further down the script)
outmode_hams = [[], []]
outmode_l0s = list(range(-6, 7))
noutmodes = len(outmode_l0s)
for ipol, pol0 in enumerate([[1, 0], [0, 1]]):
  for l0 in outmode_l0s:
    mode = GaussianModeAll(w0, l0, pol0, pos)
    # Term list handed to ed.GenerateSparseHamiltonianNoMomentumCoLex:
    #   ["hopping", i, j, dir1, dir2, mode[i,dir1]*conj(mode[j,dir2])] for i != j
    #   ["onsite",  i, dir1, |mode[i,dir1]|^2]                         for i == j
    # A disabled variant of the hopping term additionally multiplied the
    # amplitude by sector-dependent phases exp(-/+ 2j*pi*(site sector)/Lmax);
    # those factors were still computed for every site pair without being
    # used, so that dead code has been removed.
    int_mode = []
    for i in range(Nsite):
      for j in range(Nsite):
        if i != j:
          for dir1 in range(2):
            for dir2 in range(2):
              int_mode.append(["hopping", i, j, dir1, dir2, mode[i, dir1]*np.conj(mode[j, dir2])])
        else:
          for dir1 in range(2):
            int_mode.append(["onsite", i, dir1, np.abs(mode[i, dir1]**2)])
    outmode_hams0 = []
    for n in range(1, Npart + 1):
      print("Generate mode occupation operators for l0=", l0, "Npart=", n)
      ham0 = ed.GenerateSparseHamiltonianNoMomentumCoLex(int_mode, bases[n], Nsite, n)
      outmode_hams0.append(ham0)
    outmode_hams[ipol].append(outmode_hams0)

#----------------------------------------------------
# generate the orbital occupation operators
#----------------------------------------------------
# Orbital occupation operators. For every single-particle angular momentum
# sector lorb the one-particle Hermitian Hamiltonian is diagonalized and, for
# each eigenvector ("orbital"), an operator built from its site-basis
# components is generated for n = 1 .. Npart particles.
orb_hams = []        # list of {"L": sector, "iorb": index within sector, "H": [operator per n]}
orbs = []            # the orbitals themselves, in the site basis
evecssiteSP2 = np.zeros((Ns, Ns), dtype=complex)  # all orbitals as columns, ordered by (lorb, iorb)
LsSP_all = []        # angular momentum sector of each column of evecssiteSP2
numbersSP_all = []   # index within its sector of each column of evecssiteSP2
k = 0                # running column index into evecssiteSP2
for lorb in range(Lmax):
  # report the sector being processed (the original printed a stale variable
  # `L` left over from an earlier loop instead of lorb)
  print("L:", lorb)
  H1part = ed.GenerateCRSHamiltonianWithAngularMomentum(lorb, interactions2, bases[1], Lmax, mombases[1])
  H1part = np.array(H1part.todense())
  # NOTE: eigenpairs are used in the order returned by np.linalg.eig, i.e.
  # "iorb" is not an energy ordering.
  E, V = np.linalg.eig(H1part)
  Vsite = ed.ConvertEigenstates(V, bases[1], mombases[1], lorb, Lmax)
  norb = len(E)
  LsSP_all.extend(np.ones(norb)*lorb)
  numbersSP_all.extend(np.arange(norb))
  for iE in range(norb):
    gs = Vsite[:, iE]
    evecssiteSP2[:, k] = gs
    k += 1
    orbs.append(gs)
    # Term list for this orbital; component 2*i+dir of gs belongs to site i,
    # polarization dir. Unused sector phase factors that used to be computed
    # for every site pair have been removed.
    orb_int = []
    for i in range(Nsite):
      for j in range(Nsite):
        if i != j:
          for dir1 in range(2):
            for dir2 in range(2):
              orb_int.append(["hopping", i, j, dir1, dir2, gs[2*i+dir1]*np.conj(gs[2*j+dir2])])
        else:
          for dir1 in range(2):
            for dir2 in range(2):
              if dir1 == dir2:
                orb_int.append(["onsite", i, dir1, np.abs(gs[2*i+dir1]**2)])
              else:
                # same site, different polarization: entered as a hopping term
                orb_int.append(["hopping", i, i, dir1, dir2, gs[2*i+dir1]*np.conj(gs[2*i+dir2])])
    orb_hams0 = []
    for n in range(1, Npart + 1):
      print("Generate orbital occupation operators for L=", lorb, "i=", iE, "Npart=", n)
      ham0 = ed.GenerateSparseHamiltonianNoMomentumCoLex(orb_int, bases[n], Nsite, n)
      orb_hams0.append(ham0)
    orb_hams.append({"L": lorb, "iorb": iE, "H": orb_hams0})


#------------------------------------------------------------------------------
# generate the orbital bases and transformation matrices
#------------------------------------------------------------------------------
# For n = 2 .. Npart: orbital ("ordinary boson") basis of the driven sector
# and the matrix Umom = U3 @ U that is later applied to site-basis states.
# Entry n-2 of mombasesOB belongs to n particles.
mombasesOB = []
for n in range(2, Npart + 1):
  pas = ob.PascalTriangle(n, Ns)
  Nconf = pas[n, Ns]
  U = ob.TwoOrbHCToOrdinaryBosonsU(bases[n], pas)
  mombasisOB, numlistOB = ob.BasisWith1DMomentum(n, Ns, pas, LsSP_all, Lmax)
  #U2=ob.ChangeBasis(evecssiteSP2,2,Ns,pas) 
  Lgs = (n*Ltarget) % Lmax  # sector reached by driving n particles
  sector_basis = mombasisOB[Lgs]
  U3 = ob.ChangeBasisWith1DMomentum(evecssiteSP2, Lgs, mombasisOB, n, Ns, pas)
  # every configuration rewritten as a list of [L, index within sector] pairs
  basisLN = [
    [[int(LsSP_all[c]), int(numbersSP_all[c])] for c in conf]
    for conf in sector_basis
  ]
  mombasesOB.append({"basis": sector_basis, "basisLN": basisLN, "U": U3 @ U})
  print("Generate orbital basis for Npart=", n)

#------------------------------------------------------------------------------
# generate the two-photon intensity and amplitude operators
#------------------------------------------------------------------------------

# pairs (l1, l2) of output-mode indices for which two-photon amplitudes are
# evaluated, tabulated per Ltarget
lpairs_by_Ltarget = {
  0: [
    [-1, -1],
    [-2, 0],
    [-3, 1],
    [-4, 2],
    [-5, 3],
    [-6, 4],
    [-2, 4],
    [-3, 5],
  ],
  1: [
    [0, 0],
    [-1, 1],
    [-2, 2],
    [-3, 3],
    [-4, 4],
    [-5, 5],
    [-6, 6],
    [0, 6],
    [-1, 5],
    [-2, 4],
    [0, -6],
  ],
  2: [
    [1, 1],
    [0, 2],
    [-1, 3],
    [-2, 4],
    [-3, 5],
    [-4, 6],
    [-4, 0],
    #[0,-6],
  ],
}
# as before, lpairs is only defined for Ltarget in {0, 1, 2}
if Ltarget in lpairs_by_Ltarget:
  lpairs = lpairs_by_Ltarget[Ltarget]

# Two-particle amplitude vectors, one per (l1, l2) pair. Both modes use the
# polarization vector [1, 0].
# The vectors are contracted with the n == 2 component of the driven state, so
# they are built on the two-particle site basis bases[2] explicitly. The
# original relied on a leftover loop variable `basis` (equal to bases[Npart]),
# which is only the right basis when Npart == 2.
twopart_basis = bases[2]
twopart_vectors = []
for l1, l2 in lpairs:
  orb1 = GaussianModeAll(w0, l1, [1, 0], pos).flatten()
  orb2 = GaussianModeAll(w0, l2, [1, 0], pos).flatten()
  twopart_vectors.append(ed.TwoParticleAmplitudeVector2part(orb1, orb2, twopart_basis))
  print("Generate two-particle amplitude operators for Npart=", 2)




#################################################################################################
#   Solve the driven systems
#################################################################################################


# detuning grid: 200 equally spaced values from -3*B to 3*B
delta_max = 3*B
deltas = np.linspace(-delta_max, delta_max, num=200)
ndeltas = deltas.shape[0]

# the function for solving for the weakly driven state at Npart particles given that we know the state psi0 at Npart-1
# earlier version, no momentum conservation, no rotation-invariance-restoring transformation
def GetDrivenState(basis, Npart, H0, psi0, dips, field):
  """Solve (H0 - delta*Npart) x = b for the Npart-particle component.

  NOTE: the detuning `delta` is not a parameter; it is read from the
  module-level variable of that name (set by the detuning sweep), so this
  function can only be called while that variable is defined.

  basis -- site basis for Npart particles; each configuration is a sequence
           with one entry per site (0 = no excitation, c > 0 = excitation of
           kind c, used as index c-1 into dips and field)
  Npart -- number of particles of the sector being solved
  H0    -- sparse Hamiltonian of that sector (without the detuning term)
  psi0  -- state of the (Npart-1)-particle sector, indexed by ed.FindIndex
  dips  -- dipole elements, indexed by c-1
  field -- driving field, indexed as field[site, c-1]

  Returns the solution vector of the linear system.
  """
  Nconf = len(basis)
  H1 = H0 - sp.sparse.identity(Nconf, format="csr")*delta*Npart
  omegas = []  # right-hand side, one entry per configuration
  for conf in basis:
    source = 0
    for ic, c in enumerate(conf):
      if c > 0:
        # configuration with the excitation at site ic removed
        conf2 = np.array(conf)
        conf2[ic] = 0
        ind = ed.FindIndex(conf2, Npart-1)
        source += dips[c-1]*field[ic, c-1]*psi0[ind]
    omegas.append(source)
  print("solving the linear system of size", H1.shape)
  return sp.sparse.linalg.spsolve(H1, omegas)

# the function for solving for the weakly driven state at Npart particles given that we know the state psi0 at Npart-1
# newer version, with momentum conservation, with rotation-invariance-restoring transformation
def GetDrivenStateAngular(bases,mombases,Ltarget, Npart, H0, psi0, dips, field):
  """Solve for the Npart-particle component in the angular momentum basis.

  NOTE: the detuning `delta` and the symmetry order `Lmax` are not
  parameters; both are read from module-level variables.

  bases    -- site bases indexed by particle number
  mombases -- angular momentum bases indexed by particle number; each entry
              unpacks into (orbits per momentum, orbit index lookup, phase
              lookup, norm lookup)
  Ltarget  -- single-particle angular momentum sector accessed by the drive
  Npart    -- number of particles of the sector being solved
  H0       -- sparse Hamiltonian of the sector L = Ltarget*Npart % Lmax
              (without the detuning term)
  psi0     -- state of the (Npart-1)-particle sector in its angular momentum basis
  dips     -- dipole elements, indexed by c-1
  field    -- driving field, indexed as field[site, c-1]

  Returns the solution vector of the linear system (eq. (30) of the paper).
  """
  basis_start=bases[Npart] # site basis for Npart particles
  OrbitsPerMomentum_start,lin3_start,phase_start, norm_start=mombases[Npart] # angular momentum basis for Npart particles
  OrbitsPerMomentum_end,lin3_end,phase_end, norm_end=mombases[Npart-1] # angular momentum basis for Npart-1 particles
  Lstart=Ltarget*Npart % Lmax # angular momentum of the relevant symmetry sector for Npart particles
  Lend=Ltarget*(Npart-1) % Lmax # angular momentum of the relevant symmetry sector for Npart-1 particles
  Nsite=len(basis_start[0]) #number of sites
  Nsector=Nsite//Lmax #number of sites in the "sector 0" of the flake
  orbits=OrbitsPerMomentum_start[Lstart] # angular momentum basis states of the Npart-particle sector
  Nconf=len(orbits) #size of the Hilbert space for Npart particles
  H1=H0-sp.sparse.identity(Nconf,  format="csr")*delta*Npart # add detuning term to the Hamiltonian
  signs=[-1,1] # sign of the transformation phase for excitation kinds c=1 and c=2
  omegas=[]
  for orbit in orbits:     # iterate over angular momentum basis states
    norm1=len(orbit) # all site basis states of the orbit enter with the same amplitude 1/sqrt(norm1)
    temp=0
    for o in orbit: #iterate over the site basis states that form the linear combination being the angular momentum basis state
      phase1=np.exp(-2j*np.pi*float(phase_start[Lstart,o])/float(Lmax)) # phase of the site basis state in the linear combination
      conf=basis_start[o] #get the site basis state in the occupation-number-like notation (0= atomic ground state, 1= - orbital, 2 = + orbital)
      for ic, c in enumerate(conf): # iterate over sites
        if(c>0):
          conf2=np.array(conf)
          conf2[ic]=0 # form a site basis state of Npart-1 particles by annihilating one particle at site ic
          onew=ed.FindIndex(conf2, Npart-1) # find the index of conf2 in the site basis
          iorbitnew=lin3_end[Lend,onew] # find which momentum basis state does conf2 belong to
          phase2=np.exp(2j*np.pi*float(phase_end[Lend,onew])/float(Lmax)) #find the phase of conf2 in the linear combination making up the momentum basis state
          norm2=norm_end[Lend, onew] #find the norm of conf2 in the linear combination making up the momentum basis state
          L1=ic//Nsector # find which sector of the flake does site ic belong to
          phase3=np.exp(signs[c-1]*2j*np.pi*L1/float(Lmax)) # phase related to the rotational-invariance-restoring transformation
          pref=phase1*phase2/np.sqrt(norm1*norm2)*phase3
          temp+=dips[c-1]*field[ic, c-1]*psi0[iorbitnew]*pref # last term of eq. (30) in the paper
    omegas.append(temp)
  print("solving the linear system of size", H1.shape)
  sol=sp.sparse.linalg.spsolve(H1, omegas) # solve eq. (30) from the paper
  return sol
   

# |c_n|^2 for every detuning; norms_all[n] is filled during the sweep
norms_all = [[] for _ in range(Npart + 1)]


# zero-initialized result arrays; entry n-1 of each list belongs to n particles
# and the last axis always runs over the detunings
soloverlaps_all = [np.zeros((nevals_all[isec], ndeltas)) for isec in range(Npart)]
outmode_pops_all = [np.zeros((2, noutmodes, ndeltas)) for _ in range(Npart)]
orb_pops_all = [np.zeros((len(orb_hams), ndeltas)) for _ in range(Npart)]

# two-particle quantities: amplitudes per (l1, l2) pair and coefficients in the orbital basis
n_lpairs = len(lpairs)
n_orbconf = len(mombasesOB[0]["basis"])
twopart_amps = np.zeros((n_lpairs, ndeltas), dtype=complex) 
twopart_coeffs = np.zeros((n_orbconf, ndeltas), dtype=complex)
# Sweep over detunings. For each value the driven state is built sector by
# sector (n = 1 .. Npart), each sector being sourced by the one below, and the
# observables are stored in column idelta of the result arrays.
# NOTE: the loop variable must be called `delta`; GetDrivenStateAngular reads
# it as a module-level variable.
for idelta, delta in enumerate(deltas):
  print("running calculations for delta=", delta)
  sols_angular=[[1]] # wavefunction in no-excitation sector
  for n in range(1,Npart+1):
    print("n=", n)
    Lsector=Ltarget*n % Lmax
    sol0=GetDrivenStateAngular(bases,mombases,Ltarget, n, H0s_angular[n], sols_angular[n-1], [dip1,dip2], field) # get the n-particle component of the steady state
    sol0site=ed.ConvertEigenstates(np.transpose([sol0]), bases[n], mombases[n], Lsector,Lmax) # transform it to site basis
    # undo the rotational-invariance-restoring transformation; the result is
    # flattened to 1-D so that the expectation values below are true scalars.
    # (Presumably ed.OriginalPhases returns a single column like its input -
    # assigning the resulting size-1 arrays to array elements is deprecated
    # since NumPy 1.25.)
    sol0site2=np.ravel(ed.OriginalPhases(sol0site, bases[n], Lmax, Nsite))
    print("solution obtained")
    sol0site=sol0site[:,0]
    norm0=np.dot(np.conj(sol0site), sol0site)
    sol0sitenorm=sol0site/np.sqrt(norm0) #normalize n-particle component of the steady state in site basis
    norm=np.dot(np.conj(sol0), sol0)
    solnorm=sol0/np.sqrt(norm) #normalize n-particle component of the steady state in the angular momentum basis
    sols_angular.append(sol0) # store the (unnormalized) n-particle component of the steady state
    norms_all[n].append(norm) # store the |c_n|^2 coeff
    print("calculating overlaps")
    for i in range(nevals_all[n-1]):
      ov=np.abs(np.dot(np.conj(solnorm), evecs_all[n-1][:,i]))**2
      soloverlaps_all[n-1][i,idelta]=ov
    print("calculating out mode populations")
    for ipol in range(2):
      for ioutmode, l0 in enumerate(outmode_l0s):
        # evaluated with the unnormalized state in the original gauge
        pop=np.conj(sol0site2) @ outmode_hams[ipol][ioutmode][n-1] @ sol0site2
        outmode_pops_all[n-1][ipol,ioutmode, idelta]=np.real(pop)
    print("calculating orbital populations")
    for iorb,orbham in enumerate(orb_hams):
      # evaluated with the normalized state in the gauge-transformed site basis
      pop=np.conj(sol0sitenorm) @ orbham["H"][n-1] @ sol0sitenorm
      orb_pops_all[n-1][iorb, idelta]=np.real(pop)
    print("calculating two-photon amplitudes")
    if(n==2):
      for ipair, vec in enumerate(twopart_vectors):
        twopart_amps[ipair, idelta]=np.conj(vec) @ sol0site2
    print("calculating two-particle wavefunction coefficients in the orbital basis")
    if(n==2):
      coeffs=mombasesOB[n-2]["U"] @ sol0sitenorm
      twopart_coeffs[:, idelta]=coeffs

###########################################################################################################################################
# plot results
###########################################################################################################################################

# figure 1: |c_n|^2 versus detuning, one panel per particle number
fig, ax = plt.subplots(1, Npart, squeeze=False)
for n in range(Npart):
  panel = ax[0, n]
  panel.plot(deltas, norms_all[n+1])
  panel.set_ylabel(r"$|c_{}|^2$".format(n+1))
  panel.set_xlabel(r"$\Delta$")
plt.suptitle(r"Probability of having exactly $n$ particles")
plt.tight_layout()

# figure 2: overlaps of the normalized driven state with each reference eigenstate
fig, ax = plt.subplots(1, Npart, squeeze=False)
for n in range(Npart):
  panel = ax[0, n]
  for i, overlap_curve in enumerate(soloverlaps_all[n]):
    panel.plot(deltas, overlap_curve, label="state {}".format(i))
  panel.legend()
  panel.set_title("{} particles".format(n+1))
  panel.set_ylabel(r"$|\langle \psi^{(n)}_{steady} |\psi_i \rangle|^2$")
  panel.set_xlabel(r"$\Delta$")
plt.suptitle(r"Overlaps with eigenstates of Hermitian Hamiltonian")
plt.tight_layout()


# maxovloc[n][i]: detuning at which the overlap with eigenstate i of the
# (n+1)-particle sector is largest
maxovloc = [
  [deltas[np.argmax(overlap_curve)] for overlap_curve in soloverlaps_all[n]]
  for n in range(Npart)
]


# one figure per polarization: output-mode populations versus detuning
for ipol in range(2):
  polsign = ["-", "+"][ipol]
  fig, ax = plt.subplots(1, Npart, squeeze=False)
  for n in range(Npart):
    panel = ax[0, n]
    for i, l0 in enumerate(outmode_l0s):
      panel.plot(deltas, outmode_pops_all[n][ipol, i, :], label=r"$l={}$".format(l0))
    panel.legend()
    panel.set_title("{} particles".format(n+1)) 
    panel.set_ylabel(r"$I^{}_l$".format(polsign))
    panel.set_xlabel(r"$\Delta$")
  plt.suptitle(r"Intensity of {} polarized emitted light".format(polsign))
  plt.tight_layout()



# orbital populations versus detuning: color encodes the angular momentum
# sector of the orbital, line style its index within the sector
fig, ax=plt.subplots(1, Npart, squeeze=False)
linestyles=["solid", "dotted", "dashed", "dashdot", (0, (1, 10)), (0, (1, 1)), (5, (10, 3)), (0, (5, 10)), (0, (5, 1))]
for n in range(Npart):
  for iorb, orb in enumerate(orb_hams):
    c="C{}".format(orb["L"])
    # cycle through the styles; previously an index beyond the list silently
    # reused whatever style the preceding orbital had (or failed if none was set)
    ls=linestyles[orb["iorb"] % len(linestyles)]
    ax[0, n].plot(deltas, orb_pops_all[n][iorb,:], linestyle=ls, color=c, label=r"$\phi^{}_{}$".format(orb["iorb"], orb["L"]))
  ax[0,n].set_title("{} particles".format(n+1))
  ax[0,n].set_ylabel(r"$\langle n^{i}_l\rangle$")
  ax[0,n].set_xlabel(r"$\Delta$")
# legend on the last panel (panel 1 for Npart == 2; a fixed index 1 does not exist for Npart == 1)
ax[0,-1].legend()
plt.suptitle(r"Orbital populations")
plt.tight_layout()



# two-photon amplitudes: one row per (l1, l2) pair. Bar height and red line
# show |A|; bar color shows the phase of A relative to the first pair.
# The narrow second column only hosts the colorbar (in its first row).
npairs = len(lpairs)
fig, ax = plt.subplots(npairs, 2, width_ratios=[50,1], squeeze=False, sharex="col")
plt.subplots_adjust(left=0.2)
my_cmap = mpl.colormaps['twilight_shifted']
barwidth = deltas[1]-deltas[0]
for ipair, lpair in enumerate(lpairs):
  magnitude = np.abs(twopart_amps[ipair,:])
  # relative phase mapped from [0, 2*pi) to [0, 1) for the colormap
  phases = np.angle(twopart_amps[ipair,:]/twopart_amps[0,:])
  phases = phases % (2*np.pi) / (2*np.pi)
  colors = my_cmap(phases)
  left, right = ax[ipair,0], ax[ipair,1]
  left.bar(deltas, magnitude, width=barwidth, color=colors)
  left.plot(deltas, magnitude, c="r")
  left.set_ylabel(r"$|A_{{{}}}|$".format("{},{}".format(lpair[0], lpair[1])), rotation="horizontal", verticalalignment="center", horizontalalignment="right")
  left.set_ylim(0, np.max(magnitude))
  right.set_axis_off()
  for loc in maxovloc[1]:
    left.axvline(loc) #vertical lines mean maxima of the overlap with each two-particle eigenstate
ax[npairs-1,0].set_xlabel(r"$\Delta$")
sm = ScalarMappable(cmap=my_cmap, norm=plt.Normalize(0,2*np.pi))
sm.set_array([])
ax[0,1].set_axis_on()
plt.colorbar(sm, cax=ax[0,1], label=r"$\arg(A_{l_1,l_2})$")
plt.suptitle(r"Two-photon amplitudes")

# two-particle coefficients in the orbital basis: one row per basis
# configuration. Height shows |coefficient|^2; bar color shows the phase
# relative to the first configuration.
orbconfs = mombasesOB[0]["basisLN"]
nrows = len(mombasesOB[0]["basis"])
fig, ax = plt.subplots(nrows, 2, width_ratios=[50,1], squeeze=False, sharey="col", sharex="col")
my_cmap = mpl.colormaps['twilight_shifted']
barwidth = deltas[1]-deltas[0]
for i in range(nrows):
  weight = np.abs(twopart_coeffs[i,:])**2
  # relative phase mapped from [0, 2*pi) to [0, 1) for the colormap
  phases = np.angle(twopart_coeffs[i,:]/twopart_coeffs[0,:])
  phases = phases % (2*np.pi) / (2*np.pi)
  colors = my_cmap(phases)
  left, right = ax[i,0], ax[i,1]
  left.plot(deltas, weight)
  left.bar(deltas, weight, width=barwidth, color=colors)
  bln = orbconfs[i]  # [[L, index], [L, index]] of the two occupied orbitals
  left.set_ylabel(r"$|\phi^{}_{} \phi^{}_{}  \rangle$".format(bln[0][1], bln[0][0], bln[1][1], bln[1][0]), rotation="horizontal", verticalalignment="center", horizontalalignment="right") 
  right.set_axis_off()
  for loc in maxovloc[1]:
    left.axvline(loc)
ax[nrows-1,0].set_xlabel(r"$\Delta$")
sm = ScalarMappable(cmap=my_cmap, norm=plt.Normalize(0,2*np.pi))
sm.set_array([])
ax[0,1].set_axis_on()
plt.colorbar(sm, cax=ax[0,1], label="phase")
plt.suptitle(r"Two-particle wfn coeffs in orbital basis")


###########################################################################################################################################
# store all data in a JSON file
###########################################################################################################################################
# Collect everything in one JSON-serializable dictionary: the detuning grid
# plus one record per particle number n = 1 .. Npart.
data = {"deltas": list(deltas)}
polvectors = [[1,0], [0,1]]
subspace_results = []
for n in range(1, Npart+1):
  isec = n-1  # index of the n-particle sector in the result lists
  Lsec = n*Ltarget % Lmax
  states = [
    {"L": Lsec, "i": i, "E": np.real(evals_all[isec][i]), "overlaps": list(np.real(soloverlaps_all[isec][i,:]))}
    for i in range(nevals_all[isec])
  ]
  intensity = [
    {"l": l, "pol": p, "intensity": list(outmode_pops_all[isec][ip,il,:])}
    for ip, p in enumerate(polvectors)
    for il, l in enumerate(outmode_l0s)
  ]
  subs = {"Npart": n, "nexc": list(np.real(norms_all[n])), "states": states, "single-photon intensity": intensity}
  if n == 2:
    # two-particle quantities, complex values split into real and imaginary parts
    subs["two-photon amplitudes"] = [
      {"lpair": lpair, "real": list(np.real(twopart_amps[ipair,:])), "imag": list(np.imag(twopart_amps[ipair,:]))}
      for ipair, lpair in enumerate(lpairs)
    ]
    subs["two-particle orbital coeffs"] = [
      {"conf": opair, "real": list(np.real(twopart_coeffs[ipair,:])), "imag": list(np.imag(twopart_coeffs[ipair,:]))}
      for ipair, opair in enumerate(mombasesOB[0]["basisLN"])
    ]
  subspace_results.append(subs)
data["subspace_results"] = subspace_results

# file name encodes the drive type and the main system parameters
if drive == "gaussian":
  drivestr = "g_l={}_w={:.2f}".format(ldrive, w0)
elif drive == "uniform":
  drivestr = "u"
fname="driving_ang_pol_"+drivestr+"_omega={:.4f}_pol{:.4f}_{:.4f}_Ltarget={}_d0={:.4f}_D0={:.4f}_B0={:.4f}_Vquad0={:.4f}.json".format(omega, pol2[0], pol2[1],Ltarget, d0,D0,B0,Vquad0)
# context manager: the file is closed even if serialization fails part-way
with open(fname, "w") as f:
  json.dump(data, f, indent=2)

# display all figures created above (blocks until the windows are closed in interactive backends)
plt.show()
