import numpy as np
import matplotlib.pyplot as plt
import sys
import itertools as itools
import math
import os
import inspect
from colorsys import hls_to_rgb
import scipy as sp
from string import ascii_lowercase

# import ED functions
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir) 
import NonHermitianAngularMomentumED as ed


##################################################################################################
##  Start calculations
##################################################################################################

# ---------------- system parameters ---------
lambda0 = 1  # atomic transition wavelength
k0 = 2 * np.pi / lambda0  # atomic transition wavenumber
d0 = 0.2  # dimensionless lattice constant (d = d0 * lambda0)
d = d0 * lambda0
# Whether to diagonalize only the Hermitian part of the Hamiltonian (keep True,
# otherwise the orbitals will be nonorthogonal). Only the full model below takes it.
hermitian = True
D0 = 1.6  # system size, in units of the lattice constant d
D = D0 * d
simplify = False  # use the simplified model instead of the full model
Vquad0 = 10  # rescaled confining-potential strength (not useful for D0=0.6)
Lmax = 6  # order of the discrete rotational symmetry (Lmax-fold)
rtrunc = -1  # truncation radius of the interaction (-1 means no truncation)
B0 = 12  # rescaled magnetic field

B = B0 / d**3 / k0**3  # magnetic field without the rescaling
Vquad = Vquad0 / d**5 / k0**3  # confining-potential strength without the rescaling
# two primitive lattice vectors, each of length d
latvec = [[d * np.sqrt(3) / 2., -d * 0.5], [d * np.sqrt(3) / 2., d * 0.5]]
# NOTE(review): V and q are not read anywhere else in this script.
V = 0
q = 2

orb = [[1. / 3., 1. / 3.], [2. / 3., 2. / 3.]]  # positions of the atoms in the unit cell

# ----------- get lattice positions  ----------
xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)  # sites of sector 0
Nseg = len(xs)
Nsite = Nseg * Lmax
xall, yall = [], []
for sector in range(Lmax):
    # rotate sector 0 to obtain sector number `sector`
    x_rot, y_rot = ed.RotateAll(xs, ys, sector, Lmax)
    xall.extend(x_rot)
    yall.extend(y_rot)
pos = np.array([xall, yall]).T  # one (x, y) row per site



# generate interactions (with rotation-invariance-restoring transform)
if not simplify:
    interactions = ed.RealDipolarInteractionFiniteCircularFullTransform(pos, B, Vquad, k0, rtrunc, Lmax, hermitian)
else:
    interactions = ed.RealDipolarInteractionFiniteCircularSimplifiedTransform(pos, B, Vquad, k0, rtrunc, Lmax)



################################################################################
##  Find single-particle eigenstates
################################################################################

basisSP = ed.GenerateBasis(1, Nsite)  # single-particle site basis
mombasisSP = ed.FindAngularMomentumEigenstates(basisSP, 1, Lmax)  # single-particle angular momentum basis

# Per angular-momentum sector L: sorted eigenvalues, eigenvectors in the
# angular-momentum basis, and eigenvectors converted back to the site basis.
evalsSPL = []
evecsSPL = []
evecssiteSPL = []
# Flattened spectrum over all sectors (with the matching L), for the control plot.
evalsSP_all = []
LsSP_all = []
for L in range(Lmax):
    H = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions, basisSP, Lmax, mombasisSP)
    # todense() may hand back an np.matrix; convert once so that everything
    # below (eig, column slicing, ConvertEigenstates) works on plain ndarrays.
    Hdense = np.asarray(H.todense())
    # NOTE(review): np.linalg.eig does not guarantee orthogonal eigenvectors
    # inside degenerate subspaces. If H is exactly Hermitian (hermitian=True),
    # np.linalg.eigh would -- confirm before switching.
    evals, evecs = np.linalg.eig(Hdense)  # diagonalize
    order = np.argsort(np.real(evals))  # sort by the real part of the eigenvalue
    evals = evals[order]
    evecs = evecs[:, order]
    evalsSP_all.extend(evals)
    LsSP_all.extend(np.full(len(evals), L, dtype=float))
    evalsSPL.append(evals)
    evecsSPL.append(evecs)
    evecsite = ed.ConvertEigenstates(evecs, basisSP, mombasisSP, L, Lmax)  # convert eigenstates to site basis
    evecsite = ed.OriginalPhases(evecsite, basisSP, Lmax, Nsite)  # undo the rotation-invariance-restoring transformation
    evecssiteSPL.append(evecsite)

# plot single-particle spectrum (just for control, not included in the figure)
plt.figure()
plt.scatter(LsSP_all, np.real(evalsSP_all))


################################################################################
##  Find mode projections of light emitted by eigenstates
################################################################################


def LaguerreGaussModeBasicAtZero(w, l, p, pos):
    """Evaluate the normalized Laguerre-Gauss mode LG_{p,l} in its focal plane.

    Parameters
    ----------
    w : float
        Beam waist. Radii are divided by it, so it must be in the same length
        units as ``pos``.
    l : int
        Azimuthal (orbital angular momentum) index; may be negative.
    p : int
        Radial index (p >= 0).
    pos : array_like
        ``(x, y)`` pair. The components may be scalars or broadcastable
        arrays, e.g. an array of shape ``(2, N)`` evaluates N points at once.

    Returns
    -------
    complex or ndarray of complex
        ``C (sqrt(2) r/w)^|l| L_p^{|l|}(2 r^2/w^2) exp(-r^2/w^2) exp(-i l phi) / w``
        where ``C = sqrt(2 p! / (pi (p+|l|)!))``. With ``C`` and the ``1/w``
        factor the mode has unit L2 norm over the plane. Note the
        ``exp(-i l phi)`` sign convention of the azimuthal phase.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not expose
    # ``scipy.special`` on SciPy releases without lazy submodule loading.
    from scipy import special

    x, y = pos
    r0 = np.sqrt(x**2 + y**2) / w  # radius in units of the waist
    phi = np.arctan2(y, x)
    abs_l = np.abs(l)
    C = np.sqrt(2 * math.factorial(p) / (np.pi * math.factorial(p + abs_l)))  # normalization
    po = (np.sqrt(2) * r0) ** abs_l
    lag = special.genlaguerre(p, abs_l)(2 * r0**2)
    ex = np.exp(-r0**2)
    ang = np.exp(-1j * l * phi)
    return C * po * lag * ex * ang / w


ls = np.arange(-6, 7)  # l values of the LG modes to calculate mode projection into
w = 1  # LG beam waist (same length units as the site positions)
p = 0  # LG radial index

neval = 2  # number of orbitals per angular momentum subspace to include in the calculation
fig, ax = plt.subplots(1, 2 * neval, figsize=(12, 4), sharey=True)
# indexed as [orbital, angular momentum sector L, position in ls] -> |overlap|^2
projections_plus_all = np.zeros((neval, Lmax, len(ls)))
projections_minus_all = np.zeros((neval, Lmax, len(ls)))

# The mode values at the sites depend only on l, not on the orbital, so they
# are evaluated once here: row k holds conj(LG_{p, ls[k]}) at every site.
site_xy = np.transpose(pos[:Nsite])  # shape (2, Nsite): x row, y row
modes_conj = np.array([np.conj(LaguerreGaussModeBasicAtZero(w, l, p, site_xy)) for l in ls])

for ieval in range(neval):
    # Iterate over all Lmax sectors (previously a hard-coded range(6), which
    # broke as soon as Lmax != 6).
    for L in range(Lmax):
        # the current orbital in the site basis; ravel in case a column matrix is returned
        state = np.ravel(np.asarray(evecssiteSPL[L][:, ieval]))
        # amplitudes are interleaved per site: index 2*i is the "-" atomic
        # orbital and index 2*i+1 the "+" atomic orbital of site i
        amps_minus = state[0:2 * Nsite:2]
        amps_plus = state[1:2 * Nsite:2]
        projections_minus = np.abs(modes_conj @ amps_minus) ** 2  # - polarization mode projections
        projections_plus = np.abs(modes_conj @ amps_plus) ** 2  # + polarization mode projections
        projections_minus_all[ieval, L, :] = projections_minus
        projections_plus_all[ieval, L, :] = projections_plus
        # plot mode projections (for control, not included in the figure)
        ax[2 * ieval].plot(ls, projections_minus, label="L={}".format(L))
        ax[2 * ieval + 1].plot(ls, projections_plus)
    # plot labels
    ax[2 * ieval].set_title("- polarization, state #{}".format(ieval + 1))
    ax[2 * ieval + 1].set_title("+ polarization, state #{}".format(ieval + 1))
    ax[2 * ieval].set_xlabel(r"$l$")
    ax[2 * ieval + 1].set_xlabel(r"$l$")
ax[0].legend()
ax[0].set_ylabel(r"$\langle\tilde{\phi}_i |\hat{E}_{l}^{*} \hat{E}_l |\tilde{\phi}_i \rangle$")
plt.tight_layout()
for i in range(2 * neval):
    ax[i].text(0, 1.02, "(" + ascii_lowercase[i] + ")", transform=ax[i].transAxes)



###############################################################################################
#         plot the system on the top of an example mode (for control, not included in the figure)
###############################################################################################
example_l = 2  # azimuthal index of the LG mode drawn underneath the lattice
zdist = 0  # NOTE(review): not read anywhere in this script
npointsX = 30  # number of grid points along x at which the mode is evaluated
npointsY = 30  # number of grid points along y
windowX = 1  # the grid spans [-windowX, windowX] along x
windowY = 1  # the grid spans [-windowY, windowY] along y
xv = np.linspace(-windowX, windowX, npointsX, endpoint=True)
yv = np.linspace(-windowY, windowY, npointsY, endpoint=True)
xmat, ymat = np.meshgrid(xv, yv)  # NOTE(review): not read anywhere in this script
light = np.zeros((npointsX, npointsY), dtype=complex)

# ---------- evaluate the mode ---------------
# light[i, j] holds the mode at (xv[i], yv[j]), i.e. the first axis runs along x.
for (i, x), (j, y) in itools.product(enumerate(xv), enumerate(yv)):
    pos1 = np.array([x, y])
    light[i, j] = LaguerreGaussModeBasicAtZero(w, example_l, p, pos1)

def colorize(z, rmax):
    """Turn a complex array into RGB values for plotting.

    Hue encodes the phase of ``z``, lightness encodes ``|z| / rmax``
    (0 -> black, 1 -> white), and saturation is fixed at 0.8.

    Parameters
    ----------
    z : ndarray of complex
        Needs at least two dimensions; in practice a 2-D array of shape (n, m).
    rmax : float
        Modulus that is mapped to full lightness.

    Returns
    -------
    ndarray
        For a 2-D ``z`` of shape (n, m), an array of shape (m, n, 3).
        ``swapaxes(0, 2)`` moves the channel axis last *and* swaps the two
        spatial axes, so entry ``[b, a]`` is the color of ``z[a, b]``. A
        caller that stores x along the first axis of ``z`` therefore gets an
        image whose rows run along y, as ``imshow`` expects.
    """
    saturation = 0.8
    lightness = np.abs(z) / rmax
    # phase in [-pi, pi] -> hue in [0.5, 1.5]; colorsys reduces the hue modulo 1
    hue = (np.angle(z) + np.pi) / (2 * np.pi) + 0.5
    rgb = np.array(np.vectorize(hls_to_rgb)(hue, lightness, saturation))  # shape (3, *z.shape)
    return rgb.swapaxes(0, 2)

# imshow extent: the grid points are pixel centres, so pad the range by half a step
xstep, ystep = xv[1] - xv[0], yv[1] - yv[0]
xmin, xmax = xv[0] - xstep / 2, xv[-1] + xstep / 2
ymin, ymax = yv[0] - ystep / 2, yv[-1] + ystep / 2
plt.figure()
vmax = np.max(np.abs(light))  # largest |mode| on the grid -> full lightness in colorize
# NOTE(review): imshow ignores vmin/vmax for RGB input; brightness comes from colorize.
im1 = plt.imshow(colorize(light[:, :], vmax), origin="lower", extent=[xmin, xmax, ymin, ymax], vmin=0, vmax=vmax)
plt.scatter(xall, yall, c="k")  # overlay the site positions


############################################################################################################################
#   Save the mode projections to a file
###############################################################################################################################

# For every orbital and every l keep the largest projection over all angular
# momentum sectors L (axis 1 of the projection arrays).
projmax_plus = projections_plus_all.max(axis=1)
projmax_minus = projections_minus_all.max(axis=1)

# Output columns: l, then the (- polarization, + polarization) pair for orbital 1, 2, ...
out = [ls]
for ieval in range(neval):
    out += [projmax_minus[ieval, :], projmax_plus[ieval, :]]

np.savetxt(
    "modeprojections_OBC_d0={:.4f}_herm{}_simp{}_D0={:.4f}_Npart={}_B0={:.4f}_Vquad0={:.4f}_rtrunc={:.4f}.txt".format(
        d0, hermitian, simplify, D0, 1, B0, Vquad0, rtrunc
    ),
    np.transpose(out),
)
plt.show()









