import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import inspect
from string import ascii_lowercase

# Import the exact-diagonalization (ED) helper module.
# The parent directory of the directory containing this script is put at the
# front of sys.path before the import, so the module is looked up there first
# (presumably that is where NonHermitianAngularMomentumED.py lives - confirm).
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))  # directory of this file
parentdir = os.path.dirname(currentdir)  # one level up
sys.path.insert(0, parentdir) 
import NonHermitianAngularMomentumED as ed

# Plot typography: serif text (CMU Serif) and the 'cm' mathtext font set.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['CMU Serif'],
    'mathtext.fontset': 'cm',
})

########################################################################
#  diagonalize the D_0=2.1 system
########################################################################

# --- length scales ---
lambda0 = 1  # atomic transition wavelength
k0 = 2 * np.pi / lambda0  # atomic transition wavenumber
d0 = 0.05  # length of the lattice vectors in units of lambda0
d = d0 * lambda0  # length of the lattice vectors
D0 = 2.1  # system size
D = D0 * d  # D0 multiplied by d

# --- rescaled model parameters ---
B0 = 12  # rescaled magnetic field
Vquad0 = 3  # rescaled harmonic potential
B = B0 / d**3 / k0**3  # non-rescaled magnetic field
Vquad = Vquad0 / d**5 / k0**3  # non-rescaled harmonic potential
V = 0

# --- symmetry, particle number, truncation ---
Lmax = 6  # Lmax-fold rotational symmetry
Npart = 1  # number of particles
rtrunc = -1  # truncation radius of the interactions (-1 if none)

# --- lattice geometry ---
latvec = [
    [d * np.sqrt(3) / 2., -d * 0.5],
    [d * np.sqrt(3) / 2., d * 0.5],
]  # lattice vectors
orb = [
    [1. / 3., 1. / 3.],
    [2. / 3., 2. / 3.],
]  # positions of atoms in the unit cell (in the basis of lattice vectors)

# --- plotting ---
fsize = 20.5  # font size in plot labels


# ----- build the lattice -----
# Sites of sector 0; every other sector is obtained by rotating sector 0.
xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)
Nseg = len(xs)  # number of sites in a sector
Nsite = Nseg * Lmax  # total number of sites
xall, yall = [], []
for sector in range(Lmax):
    xrot, yrot = ed.RotateAll(xs, ys, sector, Lmax)
    xall.extend(xrot)
    yall.extend(yrot)
pos = np.array([xall, yall]).T  # one (x, y) row per site
Nsite = len(pos)

# ----- bases and interactions -----
# basis of configurations of sites (no symmetry)
basis = ed.GenerateBasis(Npart, Nsite)
Ns = 2 * Nsite
# interactions with transform
interactions2 = ed.RealDipolarInteractionFiniteCircularFullTransform(
    pos, B, Vquad, k0, rtrunc, Lmax, True
)
# basis of angular momentum eigenstates
mombasis = ed.FindAngularMomentumEigenstates(basis, Npart, Lmax)

# ----- loop over L - diagonalize Hamiltonian with transform  -----
evecsite_all = []  # per sector: eigenvectors converted to the site basis
evals_all = []  # eigenvalues of all sectors, concatenated
Ls_all = []  # angular momentum L belonging to each entry of evals_all
nevals = []  # number of eigenvalues found in each sector
numbers_all = []  # index of each eigenvalue within its (sorted) sector
evecsL = []  # per sector: eigenvectors as returned by the diagonalization
evals0 = []  # per sector: eigenvalue with the smallest real part
evals1 = []  # per sector: eigenvalue with the second-smallest real part
for L in range(Lmax):
    # Hamiltonian of angular-momentum sector L, converted to a dense array
    Hherm = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions2, basis, Lmax, mombasis)
    Hdense = np.array(Hherm.todense())

    # The results are called `evals`/`evecs` (not `eval`) so that the builtin
    # eval() is not shadowed at module level.
    evals, evecs = np.linalg.eig(Hdense)
    nevals.append(len(evals))

    # sort eigenvalues and eigenvectors by the real part of the eigenvalue
    order = np.argsort(np.real(evals))
    evals = evals[order]
    evecs = evecs[:, order]

    evals_all.extend(evals)
    evecsL.append(evecs)

    evecsite = ed.ConvertEigenstates(evecs, basis, mombasis, L, Lmax)  # convert eigenstates to the site basis
    evecsite = ed.OriginalPhases(evecsite, basis, Lmax, Nsite)  # undo the transform
    evecsite_all.append(evecsite)

    Ls_all.extend(np.ones(len(evals)) * L)  # angular momenta
    numbers_all.extend(range(len(evals)))  # eigenstate index
    evals0.append(evals[0])  # lowest eigenvalue (for the black line on the plot)
    evals1.append(evals[1])  # second eigenvalue (for the black line on the plot)


# ----- expectation value of the anti-Hermitian part in every eigenstate -----
# interactions defining the non-Hermitian Hamiltonian
interactions_antiherm = ed.AntiHermitianInteractionFiniteCircularFullTransform(
    pos, B, Vquad, k0, rtrunc, Lmax
)
expval_all = []
for L in range(Lmax):
    # non-Hermitian Hamiltonian of sector L
    H_antiherm = ed.GenerateCRSHamiltonianWithAngularMomentum(
        L, interactions_antiherm, basis, Lmax, mombasis
    )
    evec = evecsL[L]
    # one <v|H|v> per column v of evec (iterating evec.T walks the columns)
    expval = [np.conj(state) @ H_antiherm @ state for state in evec.T]
    expval_all.extend(expval)
rates_all = -2 * np.imag(expval_all)  # decay rates



########################################################################
#  create plot
########################################################################

# ----- figure layout -----
# Three stacked rows; the top row is split into five columns, of which
# columns 1 and 3 receive the spectra drawn further down in this script.
fig11 = plt.figure(figsize=(8, 6.5))
outer_grid = fig11.add_gridspec(
    3, 1,
    wspace=0, hspace=0,
    height_ratios=[1, 0.5, 1.2],
    right=1, left=0, bottom=0, top=0.98,
)
top_row = outer_grid[0, 0]
inner_grid1 = top_row.subgridspec(1, 5, width_ratios=[0.18, 0.95, 0.20, 0.95, 0.03])
ax1 = inner_grid1.subplots()

# ----- plot spectrum - subplot (a) -------
# The eigenvalues come from np.linalg.eig and may therefore be complex-typed;
# only their real part is used as the plot coordinate, so it is taken
# explicitly instead of relying on an implicit complex-to-float cast.
im = ax1[1].scatter(Ls_all, np.real(evals_all), c=rates_all, vmax=3, zorder=1)  # spectrum, coloured by decay rate
# Black lines through the two lowest eigenvalues of every sector; the x range
# follows Lmax rather than a hard-coded 6 so it always matches evals0/evals1.
ax1[1].plot(range(Lmax), np.real(evals0), zorder=0, c="k")
ax1[1].plot(range(Lmax), np.real(evals1), zorder=0, c="k")
ax1[1].set_ylabel(r"$E/\Gamma_0$", fontsize=fsize, labelpad=-18)
ax1[1].set_xlabel(r"$L$", fontsize=fsize)


# ----- add colorbar to subplot (a) -------
cb = plt.colorbar(im, ax=ax1[1])
cb.ax.set_ylabel(r"$\Gamma/\Gamma_0$", labelpad=-25, fontsize=fsize)
cb.ax.tick_params(axis='both', which='major', labelsize=fsize)
# prefix the last colorbar tick label with a "greater or equal" sign
ticklabels = [label.get_text() for label in cb.ax.get_yticklabels()]
ticklabels[-1] = r"$\geq$" + ticklabels[-1]
cb.ax.set_yticklabels(ticklabels)


# ----- plot eigenstates - subplot (c) -------
# Bottom row of the figure: 2 x 7 grid. Columns 1-4 hold the eigenstate
# panels, column 5 the phase colorbar.
inner_grid2 = outer_grid[2, 0].subgridspec(2, 7, width_ratios=[0.21, 1, 1, 1, 1, 0.05, 0.40])
ax2 = inner_grid2.subplots()
for L in range(4):
    ev = evecsite_all[L][:, 0]  # first stored eigenvector of sector L
    panels = ax2[:, L + 1]  # the two rows of this column
    for i, panel in enumerate(panels):  # loop over - and + orbitals
        ev_sub = ev[i::2]  # every second component, starting at i
        # marker area ~ |amplitude|^2, colour = phase of the amplitude
        im = panel.scatter(
            pos[:, 0], pos[:, 1],
            s=np.abs(ev_sub)**2 * 1000,
            c=np.angle(ev_sub),
            cmap="hsv", edgecolor="k",
            vmin=-np.pi, vmax=np.pi,
        )
        panel.set_axis_off()
        panel.set_aspect("equal")
    panels[0].set_title(r"$m={}$".format(L), fontsize=fsize)  # column labels

# row labels, placed to the left of the first panel of each row
for row_ax, ket in zip(ax2[:, 1], (r"$| -_i\rangle$", r"$| +_i\rangle$")):
    row_ax.text(
        -0.15, 0.5, ket,
        transform=row_ax.transAxes,
        horizontalalignment='center',
        verticalalignment='center',
        rotation='horizontal',
        fontsize=fsize,
    )

# phase colorbar (uses the mappable of the last panel drawn)
phase_ax = ax2[0, 5]
plt.colorbar(im, cax=phase_ax)
phase_ax.set_ylabel("phase", fontsize=fsize, labelpad=-15)
phase_ax.tick_params(axis='both', which='major', labelsize=fsize)

ax2[1, 5].set_axis_off()

############################################################
# Diagonalize the D_0=0.6 system
############################################################
#(works exactly as for the D_0=2.1 system)

# Parameters of the second system; names not reassigned here keep the
# values they already have at this point in the script.
d0 = 0.3  # length of the lattice vectors in units of lambda0
d = d0 * lambda0
D0 = 0.6  # system size
D = D0 * d
Lmax = 6  # Lmax-fold rotational symmetry

B0 = 12  # rescaled magnetic field
Vquad0 = 0  # rescaled harmonic potential
B = B0 / d**3 / k0**3
Vquad = Vquad0 / d**5 / k0**3

# lattice vectors
latvec = [
    [d * np.sqrt(3) / 2., -d * 0.5],
    [d * np.sqrt(3) / 2., d * 0.5],
]

# ----- build the lattice -----
# Sites of sector 0; every other sector is obtained by rotating sector 0.
xs, ys = ed.GetLatticePointsSegment(D0, latvec, orb, d)
Nseg = len(xs)  # number of sites in a sector
Nsite = Nseg * Lmax  # total number of sites
xall, yall = [], []
for sector in range(Lmax):
    xrot, yrot = ed.RotateAll(xs, ys, sector, Lmax)
    xall.extend(xrot)
    yall.extend(yrot)
pos = np.array([xall, yall]).T  # one (x, y) row per site
Nsite = len(pos)


# ----- bases and interactions -----
basis = ed.GenerateBasis(Npart, Nsite)
Ns = 2 * Nsite
interactions2 = ed.RealDipolarInteractionFiniteCircularFullTransform(
    pos, B, Vquad, k0, rtrunc, Lmax, True
)
mombasis = ed.FindAngularMomentumEigenstates(basis, Npart, Lmax)

# ----- loop over L - diagonalize the Hamiltonian sector by sector -----
evecsite_all = []  # per sector: eigenvectors converted to the site basis
evals_all = []  # eigenvalues of all sectors, concatenated
Ls_all = []  # angular momentum L belonging to each entry of evals_all
nevals = []  # number of eigenvalues found in each sector
numbers_all = []  # index of each eigenvalue within its (sorted) sector
evecsL = []  # per sector: eigenvectors as returned by the diagonalization
evals0 = []  # per sector: eigenvalue with the smallest real part
evals1 = []  # per sector: eigenvalue with the second-smallest real part
for L in range(Lmax):
    # Hamiltonian of angular-momentum sector L, converted to a dense array
    Hherm = ed.GenerateCRSHamiltonianWithAngularMomentum(L, interactions2, basis, Lmax, mombasis)
    Hdense = np.array(Hherm.todense())

    # The results are called `evals`/`evecs` (not `eval`) so that the builtin
    # eval() is not shadowed at module level.
    evals, evecs = np.linalg.eig(Hdense)
    nevals.append(len(evals))

    # sort eigenvalues and eigenvectors by the real part of the eigenvalue
    order = np.argsort(np.real(evals))
    evals = evals[order]
    evecs = evecs[:, order]

    evals_all.extend(evals)
    evecsL.append(evecs)

    evecsite = ed.ConvertEigenstates(evecs, basis, mombasis, L, Lmax)
    evecsite = ed.OriginalPhases(evecsite, basis, Lmax, Nsite)
    evecsite_all.append(evecsite)

    Ls_all.extend(np.ones(len(evals)) * L)  # angular momenta
    numbers_all.extend(range(len(evals)))  # eigenstate index
    evals0.append(evals[0])  # lowest eigenvalue
    evals1.append(evals[1])  # second eigenvalue

# ------ expectation value of the anti-Hermitian part in every eigenstate ------
interactions_antiherm = ed.AntiHermitianInteractionFiniteCircularFullTransform(
    pos, B, Vquad, k0, rtrunc, Lmax
)
expval_all = []
for L in range(Lmax):
    H_antiherm = ed.GenerateCRSHamiltonianWithAngularMomentum(
        L, interactions_antiherm, basis, Lmax, mombasis
    )
    evec = evecsL[L]
    # one <v|H|v> per column v of evec (iterating evec.T walks the columns)
    expval = [np.conj(state) @ H_antiherm @ state for state in evec.T]
    expval_all.extend(expval)
rates_all = -2 * np.imag(expval_all)

#------ plot spectrum - subplot (b) -------------
# The eigenvalues come from np.linalg.eig and may therefore be complex-typed;
# only their real part is used as the plot coordinate, so it is taken
# explicitly instead of relying on an implicit complex-to-float cast.
im = ax1[3].scatter(Ls_all, np.real(evals_all), c=rates_all, vmax=3)
# Black line through the lowest eigenvalue of every sector; the x range
# follows Lmax rather than a hard-coded 6 so it always matches evals0.
ax1[3].plot(range(Lmax), np.real(evals0), zorder=0, c="k")
ax1[3].set_ylabel(r"$E/\Gamma_0$", fontsize=fsize, labelpad=-15)
ax1[3].set_xlabel(r"$L$", fontsize=fsize)

# colorbar; the last tick label is prefixed with a "greater or equal" sign
cb = plt.colorbar(im, ax=ax1[3])
cb.ax.set_ylabel(r"$\Gamma/\Gamma_0$", fontsize=fsize, labelpad=-5)
ticklabels = cb.ax.get_yticklabels()
ticklabels[-1] = r"$\geq$" + ticklabels[-1].get_text()
cb.ax.set_yticklabels(ticklabels)

cb.ax.tick_params(axis='both', which='major', labelsize=fsize)

# tick label size of both spectrum panels
ax1[3].tick_params(axis='both', which='major', labelsize=fsize)
ax1[1].tick_params(axis='both', which='major', labelsize=fsize)




# -------- finalize the plot -------
# hide the spacer columns 0, 2 and 4 of the top row
for spacer in ax1[::2]:
    spacer.set_axis_off()

# hide the outermost columns (0 and 6) of the eigenstate grid, both rows
for col in (0, 6):
    for row in (0, 1):
        ax2[row, col].set_axis_off()

# panel labels
label_box = dict(facecolor='w', alpha=0.75)
ax1[1].text(0.05, 0.85, "(a)", transform=ax1[1].transAxes, fontsize=fsize, bbox=label_box)
ax1[3].text(0.05, 0.85, "(b)", transform=ax1[3].transAxes, fontsize=fsize, bbox=dict(label_box))
ax2[0, 1].text(-0.25, 0.95, "(c)", transform=ax2[0, 1].transAxes, fontsize=fsize)

# make fig11 the current figure, write it to disk, then display it
plt.figure(fig11)
plt.savefig("Figure2.pdf")
plt.show()
