import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
import sys
import os
import inspect
import math
import scipy as sp
from string import ascii_lowercase


# import ED functions: the ED module is expected in the parent directory of this
# script, so that directory is prepended to the module search path first.
currentdir = os.path.dirname(
  os.path.abspath(inspect.getfile(inspect.currentframe()))
)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
import NonHermitianAngularMomentumED as ed

# set plot fonts: serif text (CMU Serif) and Computer Modern math
plt.rcParams.update({
  'font.family': 'serif',
  'font.serif': ['CMU Serif'],
  'mathtext.fontset': 'cm',
})


##################################################################################################
##  Start calculations
##################################################################################################

# ---------------- system parameters ---------
# Lengths and wavenumbers
lambda0 = 1            # atomic transition wavelength
k0 = 2*np.pi/lambda0   # atomic transition wavenumber
d0 = 0.3               # dimensionless lattice constant (in units of lambda0)
d = d0*lambda0         # lattice constant
D0 = 0.6               # system size (in units of d)
D = D0*d               # system size
Lmax = 6               # Lmax-fold rotational symmetry

# Model switches
hermitian = True  # diagonalize the Hermitian part of the Hamiltonian (keep True, otherwise the orbitals will be nonorthogonal)
simplify = False  # use the simplified model instead of the full model
rtrunc = -1       # truncation radius of the interaction (-1: no truncation)

# Field strengths: rescaled inputs and their non-rescaled counterparts
B0 = 12          # rescaled magnetic field
Vquad0 = 0       # rescaled confining potential strength (not useful for D0=0.6)
B = B0/d**3/k0**3          # non-rescaled magnetic field
Vquad = Vquad0/d**5/k0**3  # non-rescaled confining potential strength

# Lattice vectors and positions of the atoms in the unit cell
latvec = [[d*np.sqrt(3)/2., -d*0.5], [d*np.sqrt(3)/2., d*0.5]]
orb = [[1./3., 1./3.], [2./3., 2./3.]]

# NOTE(review): V and q are not referenced anywhere else in this script.
V = 0
q = 2

fsize = 16  # plot font size

# ----------- get lattice positions (and plot them) ----------
# Sector 0 is generated directly; every other sector is a rotated copy of it.
xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)
Nseg = len(xs)       # number of sites in sector 0
Nsite = Lmax*Nseg    # total number of sites over all Lmax sectors
xall, yall = [], []
plt.figure()
for i in range(Lmax):
  xres, yres = ed.RotateAll(xs, ys, i, Lmax)  # rotate sector 0 to get sector i
  plt.scatter(xres, yres)  # one scatter call per sector
  xall += xres
  yall += yres
pos = np.array([xall, yall]).T  # one (x, y) row per site, in sector order

ax = plt.gca()
ax.set_aspect("equal")


# generate interactions (with rotation-invariance-restoring transform);
# only the full model receives the `hermitian` switch
interactions = (
  ed.RealDipolarInteractionFiniteCircularSimplifiedTransform(pos, B, Vquad, k0, rtrunc, Lmax)
  if simplify else
  ed.RealDipolarInteractionFiniteCircularFullTransform(pos, B, Vquad, k0, rtrunc, Lmax, hermitian)
)



################################################################################
##  Find single-particle eigenstates
################################################################################

basisSP=ed.GenerateBasis(1, Nsite) # single-particle site basis
mombasisSP = ed.FindAngularMomentumEigenstates(basisSP, 1, Lmax) # single-particle angular momentum basis

# Per angular-momentum sector L (list index = L):
#   evalsSPL[L]     eigenvalues sorted by real part
#   evecsSPL[L]     matching eigenvectors (columns) in the angular-momentum basis
#   evecssiteSPL[L] the same eigenvectors (columns) converted to the site basis
evalsSPL=[]
evecsSPL=[]
evecssiteSPL=[]
# Spectrum of all sectors flattened together, used only for the control plot below.
evalsSP_all=[]
LsSP_all=[]
for L in range(Lmax):
  H=ed.GenerateCRSHamiltonianWithAngularMomentum(L,interactions,basisSP,Lmax, mombasisSP)
  # General eigensolver: with hermitian=False the non-Hermitian part is kept
  # (see the parameter comment), so eigenvalues may be complex.
  evals, evecs=np.linalg.eig(H.todense()) #diagonalize
  order=np.argsort(np.real(evals)) # sort eigenvalues by their real part
  evals=evals[order]
  # np.asarray turns a possible np.matrix (from todense()) into a plain ndarray
  evecs=np.asarray(evecs[:,order])
  evalsSP_all.extend(evals)
  LsSP_all.extend(np.ones(len(evals))*L)
  evalsSPL.append(evals)
  evecsSPL.append(evecs)
  evecsite=ed.ConvertEigenstates(evecs, basisSP, mombasisSP, L,Lmax) # convert eigenstates to site basis
  evecsite=ed.OriginalPhases(evecsite, basisSP, Lmax, Nsite) # undo the rotation-invariance-restoring transformation
  evecssiteSPL.append(evecsite)

# plot single-particle spectrum (just for control, not included in the figure);
# points are colored by the imaginary part so the colorbar actually carries data
plt.figure()
plt.scatter(LsSP_all, np.real(evalsSP_all), c=np.imag(evalsSP_all))
plt.colorbar(label=r"$\mathrm{Im}\,E$")
plt.xlabel(r"$L$")
plt.ylabel(r"$\mathrm{Re}\,E$")


################################################################################
##  Find mode projections of light emitted by eigenstates
################################################################################

def LaguerreGaussModeBasicAtZero(w, l,p, pos):
  """Evaluate a Laguerre-Gauss mode in the plane z=0.

  Parameters
  ----------
  w : float
      Mode waist.
  l : int
      Azimuthal index; the mode carries the phase exp(-1j*l*phi).
  p : int
      Radial index (order of the generalized Laguerre polynomial).
  pos : pair (x, y)
      Transverse coordinates. x and y may be scalars or equally shaped arrays
      (e.g. ``positions.T`` for an (N, 2) array of points), in which case the
      mode is evaluated elementwise.

  Returns
  -------
  complex or ndarray of complex
      C * (sqrt(2) r/w)**|l| * L_p^{|l|}(2 r**2/w**2) * exp(-r**2/w**2)
      * exp(-1j*l*phi) / w, with C = sqrt(2 p! / (pi (p+|l|)!)).
  """
  # Import the submodule explicitly: older SciPy releases do not load
  # scipy.special on a bare `import scipy`, so `sp.special` can raise AttributeError.
  from scipy.special import genlaguerre

  x,y=pos
  abs_l=np.abs(l)
  r=np.sqrt(x**2+y**2)
  phi=np.arctan2(y,x)
  C=np.sqrt(2*math.factorial(p)/(np.pi*math.factorial(p+abs_l))) # normalization
  r0=r/w # radius in units of the waist
  po=(np.sqrt(2)*r0)**abs_l
  lag=genlaguerre(p,abs_l)(2*r0**2)
  ex=np.exp(-r0**2)
  ang=np.exp(-1j*l*phi)
  return C*po*lag*ex*ang/w


ls=np.arange(-6,7) # the values of l used in the computation
w=1 # LG mode waist
p=0 # LG mode p

neval=2 # number of eigenvalues per momentum sector to include

# Conjugated LG mode amplitudes at every site, one row per l: shape (len(ls), Nsite).
# They do not depend on the emitting state, so they are evaluated once here
# instead of inside the state/site loops.
lg_conj=np.array([np.conj(LaguerreGaussModeBasicAtZero(w, l, p, pos[:Nsite].T)) for l in ls])

# squeeze=False keeps ax 2-D even when neval == 1
fig,ax =plt.subplots(neval, 2, figsize=(6, 4.5), sharey=True, sharex=True, squeeze=False)
# mode intensities |E_l|^2, indexed as [ieval, L, index into ls]
projections_plus_all=np.zeros((neval, Lmax,len(ls)) )
projections_minus_all=np.zeros((neval, Lmax,len(ls)) )
for ieval in range(neval):
  for L in range(Lmax):
    # orbital to compute the emission from (flattened to 1-D in case a 2-D column comes back)
    state=np.asarray(evecssiteSPL[L])[:,ieval].ravel()
    # basis layout: index 2*i is the minus and 2*i+1 the plus atomic orbital of site i
    state_minus=state[0:2*Nsite:2]
    state_plus=state[1:2*Nsite:2]
    # mode projection E_l = sum_i c_i * conj(LG_l(r_i)); the plotted quantity is |E_l|^2
    projections_minus=np.abs(lg_conj @ state_minus)**2 # - polarization
    projections_plus=np.abs(lg_conj @ state_plus)**2 # + polarization
    projections_minus_all[ieval, L]=projections_minus
    projections_plus_all[ieval, L]=projections_plus
    # plot the mode projection
    ax[ieval,0].plot(ls, projections_minus, label=r"$L={}$".format(L), marker="o")
    ax[ieval,1].plot(ls, projections_plus, marker="o")
  # plot labels etc.
  ax[ieval,0].set_title(r"$-$ polarization, state #{}".format(ieval), fontsize=fsize)
  ax[ieval,1].set_title(r"$+$ polarization, state #{}".format(ieval), fontsize=fsize)
  ax[ieval,0].set_yscale("log")
  ax[ieval,0].set_ylabel(r"$I^{\sigma}_l$", fontsize=fsize, labelpad=-5)
  ax[ieval,0].tick_params(axis='both', which='major', labelsize=fsize)
  ax[ieval,1].tick_params(axis='both', which='major', labelsize=fsize)
ax[neval-1,0].set_xlabel(r"$l$", fontsize=fsize)
ax[neval-1,1].set_xlabel(r"$l$", fontsize=fsize)
ax[0,0].set_ylim([10**(-13), 10])

# subplot labels (a), (b), ... in row-major order
for k, axis in enumerate(ax.flat):
  axis.text(0.02, 0.8, "("+ascii_lowercase[k]+")", transform=axis.transAxes, fontsize=fsize)

# leave room at the bottom of the figure for the legend
plt.tight_layout(rect=[0,0.15, 1, 1])

# add legend; bbox_to_anchor is in ax[0,0] axes coordinates and the -2.4 offset
# was tuned for neval=2 (retune if the number of rows changes)
ax[0,0].legend(ncols=3,loc="center", bbox_to_anchor=[-0, -2.4, 2, 0.1], fontsize=fsize)

plt.savefig("Figure4.pdf")

plt.show()











