import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.image as mpimg
import matplotlib.cbook as cbook
from scipy.special import erfi
#from pythtb import * # import TB model class



#system parameters (module-level constants; several are read as globals by the functions below)
lambda0=1  # atomic transition wavelength
k0=2*np.pi/lambda0 # atomic transition wavenumber
d0=0.1 # dimensionless lattice constant (in units of lambda0)
d=d0*lambda0 # full lattice constant
eta=0.1 # regularization length in units of the lattice constant (aho = eta*d below)
aho=eta*d # regularization (see supplementary material to https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.119.023603); read as a global by GreenMS and Hamiltonian
onlyreal=False # calculate Hermitian part (True) or full Hamiltonian; read as a global by Hamiltonian
B0=12 #rescaled magnetic field
fsize=18 #font size in the figure


B=B0/d**3/k0**3 # full magnetic field, i.e. B0/(k0*d)**3
print("B=", B)
# triangular Bravais lattice: two vectors of length d at 60 degrees to each other
latvec=[[d,0], [d/2,np.sqrt(3)/2*d]]

# two-site basis at 0 and (a1+a2)/3, i.e. a honeycomb arrangement of atoms
orb=[[0,0], [1./3.,1./3.]] # positions of atoms in the unit cell, in the basis of lattice vectors
norb=len(orb)  # number of atoms per unit cell (note: each atom contains 2 orbitals)
onsite=[0,0]  # onsite potentials, one per atom

#set plot label font
# NOTE: 'CMU Serif' must be installed; otherwise matplotlib warns and falls back to another serif font
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['CMU Serif']
plt.rcParams['mathtext.fontset'] = 'cm'


def CellArea(latvec):
  """Return the area of the 2-D unit cell spanned by the two lattice vectors.

  latvec - pair of 2-D real-space lattice vectors [v1, v2]

  The area is |det[v1, v2]| = |v1_x*v2_y - v1_y*v2_x|. It is computed directly
  because np.cross on 2-D vectors is deprecated since NumPy 2.0. The result does
  not depend on the orientation (order) of the two vectors.
  """
  v1,v2=latvec
  return float(abs(v1[0]*v2[1]-v1[1]*v2[0]))

def ReciprocalVectors2D(latvec):
  """Return the reciprocal lattice vectors (b1, b2) of a 2-D lattice.

  latvec - pair of real-space lattice vectors [a1, a2]

  The result satisfies a_i . b_j = 2*pi*delta_ij. Each b is built from the
  other lattice vector turned by a quarter turn, (x, y) -> (y, -x), and then
  normalised so that its projection on the matching a equals 2*pi.
  """
  a1 = np.array(latvec[0])
  a2 = np.array(latvec[1])
  rot = np.array([[0., 1.], [-1., 0.]])
  a1_perp = rot @ a1
  a2_perp = rot @ a2
  b1 = 2*np.pi*a2_perp/np.dot(a1, a2_perp)
  b2 = 2*np.pi*a1_perp/np.dot(a2, a1_perp)
  return b1, b2

def GreenMS(kvec,recvec, k0, cellarea,bvec, extent=30):
  """In-plane (2x2) lattice Green's tensor evaluated as a reciprocal-lattice sum.

  See the supplementary material to
  https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.119.023603.
  With q = g - k and lam = sqrt(k0^2 - |q|^2), this computes

      G = 1/(2 k0^2 A) * sum_g (k0^2 I - q q^T) * (1j - erfi(aho*lam/sqrt(2)))
                               * exp(-1j b.q) / lam

  Outside the light cone (|q| > k0), lam is taken as +1j*sqrt(|q|^2 - k0^2).
  The regularization length ``aho`` is read from the module-level constant.

  kvec - Bloch momentum (kx, ky)
  recvec - reciprocal lattice vectors
  k0 - atomic transition wavenumber
  cellarea - area of the unit cell
  bvec - real-space position difference between the two sites
  extent - the sum runs over g = i1*b1 + i2*b2 with |i1|, |i2| <= extent
           (default 30, the previously hard-coded cutoff)

  Returns a complex (2, 2) array.
  NOTE: a term with |g - k| == k0 exactly gives lam == 0 and divides by zero.
  """
  kvec=np.asarray(kvec, dtype=float)
  bvec=np.asarray(bvec, dtype=float)
  recvec1=np.asarray(recvec[0], dtype=float)
  recvec2=np.asarray(recvec[1], dtype=float)
  # all reciprocal lattice vectors g in the truncated (2*extent+1)^2 window, shape (N, 2)
  idx=np.arange(-extent, extent+1)
  i1, i2=np.meshgrid(idx, idx, indexing="ij")
  gvecs=np.outer(i1.ravel(), recvec1)+np.outer(i2.ravel(), recvec2)
  qvecs=gvecs-kvec
  lamsq=k0**2-np.einsum("ni,ni->n", qvecs, qvecs)
  # real branch inside the light cone, positive-imaginary branch outside it
  root=np.sqrt(np.abs(lamsq))
  lam=np.where(lamsq>=0, root+0j, 1j*root)
  # (k0^2 I - q q^T) for every g, shape (N, 2, 2)
  tensors=k0**2*np.identity(2)-qvecs[:, :, None]*qvecs[:, None, :]
  weights=(1j-erfi(aho*lam/np.sqrt(2)))*np.exp(-1j*(qvecs@bvec))/lam
  G=np.einsum("n,nij->ij", weights, tensors)
  return G/(2*k0**2*cellarea)

def G0(k0, aho):
  """Regularized Green's function at the source (coincident points).

  See the supplementary material to
  https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.119.023603.
  With x = k0*aho:

      G0 = k0/(6 pi) * [ (erfi(x/sqrt(2)) - 1j) * exp(-x^2/2)
                         - (x^2 - 1/2) / (sqrt(pi/2) * x^3) ]

  k0 - atomic transition wavenumber
  aho - regularization length
  Returns a complex scalar.
  """
  x = k0*aho
  erfi_part = (erfi(x/np.sqrt(2)) - 1j)/np.exp(x**2/2)
  power_part = (-0.5 + x**2)/(np.sqrt(0.5*np.pi)*x**3)
  return k0/(6*np.pi)*(erfi_part - power_part)

def SimplifiedGXYBasis(rvec, k0):
  """Real-space Green's tensor of the simplified model, in the (x, y) basis.

      G_ij = (delta_ij - 3 r_i r_j / r^2) / (4 pi r^3 k0^2)

  rvec - 2-D difference between positions of two sites (must be non-zero)
  k0 - atomic transition wavenumber
  Returns a complex (2, 2) array.
  """
  x,y=rvec  # enforces a 2-component separation
  rvec=np.array([x, y], dtype=float)
  r=np.sqrt(x**2+y**2)
  prefactor=1./(4*np.pi*r**3)/(k0**2)
  G=prefactor*(np.identity(2)-3*(np.outer(rvec, rvec)/(r**2)))
  return G.astype(complex)


def SimplifiedHamiltonian(kvec, k0, latvec, orb, onsite, B, extent=10):
  """Bloch Hamiltonian of the simplified model at momentum kvec.

  For every pair of atoms (i, j), this sums 3*pi/k0 * SimplifiedGXYBasis(r, k0)
  * exp(-1j k.r) over separations r = pos_i - (pos_j + d1*a1 + d2*a2), with
  |d1|, |d2| <= extent. The coincident-site term r = 0 is skipped. Each atom
  carries two orbitals, so the result is a complex (2*norb, 2*norb) matrix.

  kvec - reciprocal-space point at which to calculate the Hamiltonian
  k0 - atomic transition wavenumber
  latvec - real-space lattice vectors
  orb - positions of sites in the unit cell in the basis of lattice vectors
  onsite - onsite potentials, added to both orbitals of each atom (as in Hamiltonian)
  B - magnetic field (not rescaled)
  extent - real-space truncation of the lattice sum (default 10)

  NOTE(review): the sign of the +-1j*B coupling here is opposite to the one in
  Hamiltonian(); confirm this is intended.
  """
  latvec1=np.array(latvec[0])
  latvec2=np.array(latvec[1])
  norb=len(orb)
  H=np.zeros((2*norb,2*norb), dtype=complex)
  # Separations much shorter than the lattice vectors count as r = 0. The
  # tolerance scales with the lattice, unlike the former fixed 1e-5, so genuine
  # neighbours are not dropped when the lattice constant is small.
  tol=1e-6*min(np.dot(latvec1,latvec1), np.dot(latvec2,latvec2))
  for i in range(norb):
    pos1=latvec1*orb[i][0]+latvec2*orb[i][1]
    for j in range(norb):
      for d1 in range(-extent,extent+1):
        for d2 in range(-extent, extent+1):
          pos2=latvec1*(orb[j][0]+d1)+latvec2*(orb[j][1]+d2)
          rvec=pos1-pos2
          if np.dot(rvec,rvec)>tol:
            G=SimplifiedGXYBasis(rvec, k0)
            H[2*i:(2*i+2),2*j:(2*j+2)]+=3*np.pi/k0*G*np.exp(-1j*np.dot(kvec,rvec))
      if i==j:
        H[2*i:(2*i+2),2*i:(2*i+2)]+=np.identity(2)*onsite[i]
        H[2*i,2*i+1]-=1j*B
        H[2*i+1,2*i]+=1j*B
  return H


def Hamiltonian(kvec,recvec, k0, cellarea, latvec, orb, onsite, B):
  """Full-model Bloch Hamiltonian at momentum kvec.

  For each atom pair (row, col) with separation b = pos_row - pos_col, the 2x2
  block is -3*pi/k0 * GreenMS(kvec, recvec, k0, cellarea, b) * exp(-1j k.b).
  Diagonal atom blocks additionally get onsite[row] on both orbitals, plus
  +1j*B at (2*row, 2*row+1) and -1j*B at (2*row+1, 2*row). Finally,
  0.5j + 3*pi/k0*G0(k0, aho) is subtracted from every diagonal element.
  When the module-level flag ``onlyreal`` is True, the Hermitian part
  (H + H^dagger)/2 is returned instead.

  recvec - reciprocal lattice vectors
  k0 - atomic transition wavenumber
  cellarea - unit cell area
  latvec - lattice vectors in real space
  orb - positions of sites in the unit cell in the basis of lattice vectors
  onsite - onsite potentials
  B - magnetic field (not rescaled)
  Reads the module-level globals ``aho`` and ``onlyreal``.
  """
  a1 = np.array(latvec[0])
  a2 = np.array(latvec[1])
  nsites = len(orb)
  positions = [a1*frac[0] + a2*frac[1] for frac in orb]
  ham = np.zeros((2*nsites, 2*nsites), dtype=complex)
  for row, pos_row in enumerate(positions):
    rows = slice(2*row, 2*row + 2)
    for col, pos_col in enumerate(positions):
      cols = slice(2*col, 2*col + 2)
      sep = pos_row - pos_col
      green = GreenMS(kvec, recvec, k0, cellarea, sep)
      ham[rows, cols] = -3*np.pi/k0*green*np.exp(-1j*np.dot(kvec, sep))
      if row == col:
        ham[rows, cols] += np.identity(2)*onsite[row]
        ham[2*row, 2*row + 1] += 1j*B
        ham[2*row + 1, 2*row] -= 1j*B
  diagonal_shift = 0.5j + 3*np.pi/k0*G0(k0, aho)
  ham = ham - np.identity(2*nsites)*diagonal_shift
  if onlyreal:
    ham = (ham + ham.conj().T)/2.
  return ham

def SpecialPoints(lattice_type, recvec):
  """Return a dict mapping LaTeX labels of high-symmetry points to momenta.

  lattice_type - "hex" (Gamma, M, K) or "square" (Gamma, X, M)
  recvec - reciprocal vectors (b1, b2)

  Labels are the strings used as plot tick labels, e.g. r"$\\Gamma$". They must
  match the names passed to CreateMomentumPath.
  Raises ValueError for an unknown lattice_type.
  """
  recvec1=np.array(recvec[0])
  recvec2=np.array(recvec[1])
  if lattice_type=="hex":
    return {r"$\Gamma$":np.array([0,0]), r"$M$":recvec1/2, r"$K$":recvec1*2/3+recvec2/3}
  if lattice_type=="square":
    # the Gamma label needs the backslash so it matches the hex lattice and renders as a Greek letter
    return {r"$\Gamma$":np.array([0,0]), r"$X$":recvec1/2, r"$M$":recvec1/2+recvec2/2}
  raise ValueError(f"unknown lattice_type {lattice_type!r}; expected 'hex' or 'square'")

def CreateMomentumPath(pointnames, numbers, specpoints):
  """Build a piecewise-linear path through special points in momentum space.

  pointnames - list of special-point names (keys of specpoints) setting the path
  numbers - numbers[i] points are placed on segment i, including its start point
            and excluding its end point
  specpoints - dictionary mapping names to 2-D positions

  Returns a float array of shape (2, sum(numbers[:len(pointnames)-1]) + 1). The
  final special point is appended once at the end. A single name yields a
  (2, 1) array holding that point.
  Raises ValueError if pointnames is empty or numbers has fewer entries than
  there are segments, and KeyError for an unknown point name.
  """
  npoints=len(pointnames)
  if npoints==0:
    raise ValueError("pointnames must contain at least one special point")
  if len(numbers)<npoints-1:
    raise ValueError("numbers must give a point count for every segment of the path")
  segments=[]
  for name_start, name_end, n in zip(pointnames[:-1], pointnames[1:], numbers):
    start=np.asarray(specpoints[name_start], dtype=float)
    end=np.asarray(specpoints[name_end], dtype=float)
    frac=np.linspace(0, 1, n, endpoint=False)
    segments.append(start[:, None]+np.outer(end-start, frac))
  last=np.asarray(specpoints[pointnames[-1]], dtype=float)
  segments.append(np.reshape(last, (2, 1)))
  return np.concatenate(segments, axis=1)

# get cell area
cellarea=CellArea(latvec)
print("Cell area:",cellarea)
# get reciprocal vectors
recvec=ReciprocalVectors2D(latvec)
recvec1=np.array(recvec[0])
recvec2=np.array(recvec[1])
print("reciprocal vectors:", recvec)

#define subplots: (a) image in column 0, (b) band structure in column 2;
#the remaining grid columns hold spacing and the colorbar
fig=plt.figure(figsize=(7,4))
spec = gridspec.GridSpec(ncols=6, nrows=3, figure=fig, left=0, right=1, bottom=0, top=0.975, wspace=0, hspace=0, height_ratios=[0.05, 0.95,0.2], width_ratios=[1.2, 0.45, 1, 0.05, 0.08, 0.35])

ax_main0=fig.add_subplot(spec[1:, 0])
ax_main1=fig.add_subplot(spec[:2, 2])

# get the path in the momentum space
specialpoints=SpecialPoints("hex", recvec)
nks=[100,100,100] #number of data points between each pair of special points
nk=sum(nks)+1  # total number of k points on the path (end point included)
pointlabels=[r"$\Gamma$", r"$K$", r"$M$", r"$\Gamma$"]
kpath=CreateMomentumPath(pointlabels, nks, specialpoints)
# tiny uniform offset, presumably so no k point sits exactly on a singular or
# degenerate point -- TODO confirm
kpath=kpath+0.000001


# cumulative distance along the path; used as the x coordinate and to place
# the vertical lines denoting the special points
steps=np.linalg.norm(np.diff(kpath, axis=1), axis=0)
kpath_dist=[0]+list(np.cumsum(steps))


# loop over path in momentum space, obtain energies
Es=[]   # full-model complex eigenvalues, sorted by real part at each k
Es2=[]  # simplified-model (Hermitian) eigenvalues at each k
ks2=[]  # path coordinate repeated once per eigenvalue
for ik in range(nk):
  kvec=kpath[:,ik]
  kx,ky=kvec
  print("k=",kx, ky)
  H=Hamiltonian(kvec,recvec, k0, cellarea, latvec, orb, onsite, B)
  E=np.linalg.eigvals(H)
  ind=np.argsort(np.real(E))
  E=E[ind]
  Es.extend(E)
  ks2.extend(np.ones(len(E))*kpath_dist[ik])
  H=SimplifiedHamiltonian(kvec, k0, latvec, orb, onsite, B)
  E=np.linalg.eigvalsh(H)
  Es2.extend(E)
     
# plot spectra: simplified model in light gray underneath, full model on top
ax_main1.scatter(ks2, Es2, c="lightgray", s=1, zorder=5)
# full model colored by -2*Im(E), shown on the colorbar as Gamma/Gamma_0 (color scale saturates at vmax=3)
im=ax_main1.scatter(ks2, np.real(Es), c=-2*np.imag(Es), vmax=3, s=1, zorder=6)
ax_main1.set_ylabel(r"$\mathrm{Re}(E)/\Gamma_0$", fontsize=fsize, labelpad=-10)
ax_main1.set_xlabel(r"$k$ along the path", fontsize=fsize)

# plot vertical lines denoting special points and add x ticks
nk1=0
nks2=[0]
for nk0 in nks:
    ax_main1.axvline(kpath_dist[nk1])
    nk1=nk1+nk0
    nks2.append(kpath_dist[nk1])
ax_main1.axvline(kpath_dist[nk1])
ax_main1.set_xticks(nks2)
ax_main1.set_xticklabels(pointlabels)


#add colorbar
axcb=fig.add_subplot(spec[:2, 4])
plt.colorbar(im, cax=axcb)
axcb.set_ylabel(r"$\Gamma / \Gamma_0$", labelpad=-15, fontsize=fsize)
# the color scale saturates at vmax, so the top tick gets a ">=" prefix
ticklabels=axcb.get_yticklabels()
ticklabels[-1]=r"$\geq$"+ticklabels[-1].get_text()
axcb.set_yticklabels(ticklabels)
axcb.tick_params(axis='both', which='major', labelsize=fsize)

# color the light cone (|k| < k0) in grey. The path starts and ends at Gamma,
# so on the first segment the path coordinate equals |k|, and on the last
# segment max(ks2) - coordinate equals |k|. k0 is used instead of a hard-coded
# 2*pi so the shading stays correct when lambda0 != 1.
# NOTE(review): assumes k0 < |K| and k0 < |M|, so the cone never reaches the middle segment.
ax_main1.axvspan(0,k0, alpha=1, color='gray', zorder=1)
ax_main1.axvspan(np.max(ks2)-k0, np.max(ks2), alpha=1, color='gray', zorder=1)


ax_main1.set_ylim(-150,100) # set y limits
ax_main1.tick_params(axis='both', which='major', labelsize=fsize) # set tick labels font size

# add inset with the Brillouin zone (image file expected in the working directory)
image = mpimg.imread("brillouin.png")
axins=ax_main1.inset_axes((0.32, 0.35, 0.4, 0.5))
axins.imshow(image, rasterized=True)
axins.set_axis_off()

# plot the image in subplot (a) (image file expected in the working directory)
image = mpimg.imread("flake3.png")
ax_main0.imshow(image, rasterized=True)
ax_main0.set_axis_off()

# subplot labels
ax_main0.text(-0.0, 1.015, "(a)", transform=ax_main0.transAxes, fontsize=fsize)
ax_main1.text(-0.4, 0.96, "(b)", transform=ax_main1.transAxes, fontsize=fsize)

#save the file and show the plot
plt.savefig("Figure1.pdf",dpi=300)
plt.show()
