#!/usr/bin/env python

from __future__ import print_function
from pythtb import * # import TB model class
import numpy as np
import matplotlib.pyplot as plt


# Length scales of the problem.
# They are deliberately floats: this script only opts in to print_function
# (not division), so under Python 2 an integer lattice constant would make
# d/2 floor to 0 and silently turn the triangular lattice into a rectangular one.
lambda0=1.0 # atomic wavelength
k0=2*np.pi/lambda0 # atomic wavenumber
d=lambda0 # lattice constant (does not matter for the result except for a multiplicative constant in the energy spectrum)
square_step=31 # number of points for Chern number and band structure calculations (square_step x square_step)

# define lattice vectors: two vectors of length d, 60 degrees apart (triangular lattice)
lat=[[d,0], [d/2.0,np.sqrt(3)/2*d]]

# Atom positions inside the unit cell, given in units of the lattice vectors.
orb0=[[0,0]]
# Every atom carries two orbitals. The offset term is multiplied by 0.0, so
# both orbitals of an atom are placed exactly at the atom's own position.
orb=[np.array(site)+0.0*orb_idx*np.array([1,1])
     for site in orb0
     for orb_idx in range(2)]
nband=len(orb) # total number of orbitals (two per atom)

# Range of the dipolar hopping, in unit cells: (2*extentX+1) x (2*extentY+1)
extentX=5
extentY=extentX

B0s=np.arange(0,1,0.1) # values of B0 (rescaled magnetic field) for calculation
cherns=[] # Chern number of the lowest band, one entry per B0
flatness_ratios_up=[] # flatness ratio (gap/bandwidth) of the upper band, one entry per B0
flatness_ratios_dn=[] # flatness ratio (gap/bandwidth) of the lower band, one entry per B0

# Tabulate the hopping matrix elements once, before the field sweep: they
# depend only on the lattice geometry and on k0, never on B0.
# Only one half of the neighbourhood is listed (ix>0, or ix==0 with iy>0, or
# pairs inside the home cell with io2>io1); pythtb's set_hop is documented to
# add the Hermitian-conjugate partner of every hopping by itself.
hoppings=[] # entries: (amplitude, orbital index i, orbital index j, [ix, iy])
for ix in range(0, extentX+1):
    for iy in range(-extentY, extentY+1):
        for io1, o1 in enumerate(orb0):
            for io2, o2 in enumerate(orb0):
                if not ((iy>0) or (ix>0) or ((ix==0) and (iy==0) and (io2>io1))):
                    continue
                # Cartesian separation between atom o1 in the home cell and
                # atom o2 in the cell displaced by (ix, iy)
                rvec1=o1[0]*np.array(lat[0])+o1[1]*np.array(lat[1])
                rvec2=(o2[0]+ix)*np.array(lat[0])+(o2[1]+iy)*np.array(lat[1])
                rvec=rvec2-rvec1
                phi=np.arctan2(rvec[1],rvec[0]) # in-plane angle of the separation
                r=np.sqrt(rvec[0]**2+rvec[1]**2) # length of the separation
                # 1/r^3 prefactor shared by all four orbital combinations
                pref=-1./(8*np.pi*r**3)/(k0**2)
                pref*=3*np.pi/k0
                # orbital-conserving hoppings
                hoppings.append((pref, 2*io1+0, 2*io2+0, [ix, iy]))
                hoppings.append((pref, 2*io1+1, 2*io2+1, [ix, iy]))
                # orbital-flipping hoppings carry a phase of +/- 2*phi
                hoppings.append((3*np.exp(2j*phi)*pref, 2*io1+0, 2*io2+1, [ix, iy]))
                hoppings.append((3*np.exp(-2j*phi)*pref, 2*io1+1, 2*io2+0, [ix, iy]))

# side length of the k-space patch in reduced coordinates; 1 covers the
# whole Brillouin zone
square_length=1

for B0 in B0s:
    B=B0/d**3/k0**3 # find the non-rescaled version of magnetic field
    my_model=tb_model(2,2,lat,orb) # define the tight-binding model
    # Define onsite matrix elements (in the circular basis): the two orbitals
    # of each atom are shifted by -B and +B respectively
    onsite0=[0,0]
    onsite=[]
    for io in range(len(orb0)):
        onsite.append(-B+onsite0[io])
        onsite.append(B+onsite0[io])
    my_model.set_onsite(onsite)
    # Install the precomputed hopping matrix elements
    for amplitude, orb_i, orb_j, cell in hoppings:
        my_model.set_hop(amplitude, orb_i, orb_j, cell)

    # two-dimensional wf_array to store wavefunctions on the k-mesh
    w_square=wf_array(my_model,[square_step,square_step])
    all_kpt=np.zeros((square_step,square_step,2)) # k-points of the mesh (not used further in this script)
    evals=np.zeros((square_step, square_step, nband)) # band energies on the mesh
    # now populate array with wavefunctions
    for i in range(square_step):
        for j in range(square_step):
            # construct k-point on the square patch
            kpt=np.array([square_length*(float(i)/float(square_step-1)),
                          square_length*(float(j)/float(square_step-1))])
            # store the k-point
            all_kpt[i,j,:]=kpt
            # find eigenvalues and eigenvectors at this k-point
            # (named so that the builtin eval() is not shadowed)
            (band_energies,band_states)=my_model.solve_one(kpt,eig_vectors=True)
            # store eigenvectors into wf_array object
            w_square[i,j]=band_states
            evals[i,j,:]=band_energies

    # The mesh runs from 0 to 1 (inclusive) in reduced coordinates, i.e. across
    # a full reciprocal-lattice vector in each direction, so the last mesh
    # point describes the same states as the first. Tie them together so that
    # the Berry flux is evaluated on a properly closed mesh.
    w_square.impose_pbc(0,0)
    w_square.impose_pbc(1,1)

    # compute Berry flux on this square patch
    print("B0=", B0)
    print("Chern number:")
    C=w_square.berry_flux([0])/(2*np.pi)
    print("  for band ", 0, " equals    : ", C)
    cherns.append(C)
    # compute flatness ratios: band gap divided by the width of each band
    bandwidth_dn=np.max(evals[:,:,0])-np.min(evals[:,:,0])
    bandwidth_up=np.max(evals[:,:,1])-np.min(evals[:,:,1])
    gap=np.min(evals[:,:,1])-np.max(evals[:,:,0])
    flatness_ratios_up.append(gap/bandwidth_up)
    flatness_ratios_dn.append(gap/bandwidth_dn)
# Plot the results of the field sweep side by side:
# left panel  - Chern number of the lower band versus B0
# right panel - flatness ratios of both bands versus B0
fig, (ax_chern, ax_flat)=plt.subplots(1,2)

ax_chern.plot(B0s, cherns)
ax_chern.set_xlabel(r"$B_0$")
ax_chern.set_ylabel("Chern number of the lower band")

ax_flat.plot(B0s, flatness_ratios_up, label="Upper band")
ax_flat.plot(B0s, flatness_ratios_dn, label="Lower band")
ax_flat.set_xlabel(r"$B_0$")
ax_flat.set_ylabel("Flatness ratio")
ax_flat.legend()

# write the figure to disk first, then open the interactive window
plt.tight_layout()
plt.savefig("triangular.pdf")
plt.show()

