#!/usr/bin/env python

from __future__ import print_function
from pythtb import * # import TB model class
import numpy as np
import matplotlib.pyplot as plt


# Physical scales: wavelength and wavenumber of the atomic transition.
lambda0 = 1
k0 = 2*np.pi/lambda0
# Lattice constant; apart from a multiplicative constant on the energy
# spectrum it does not affect the results.
d = lambda0
# Number of k-points per direction of the square grid used for the Chern
# number and band-structure calculation (square_step x square_step points).
square_step = 31

# Square Bravais lattice vectors.
lat = [[d, 0.0], [0.0, d]]

# Atom positions inside the unit cell, in reduced (lattice-vector) coordinates.
orb0 = [[0, 0], [0.5, 0.5]]
# Every atom hosts two orbitals, so each position is repeated twice in a row:
# orbitals 2*a and 2*a+1 sit on atom a.
orb = [np.array(site) for site in orb0 for _ in (0, 1)]
nband = len(orb)

# Number of unit cells kept in the dipolar-hopping sum: a patch of
# (2*extentX+1) x (2*extentY+1) cells.
extentX = 5
extentY = extentX

# Sweep of the rescaled magnetic field B0 and of the rescaled on-site
# potential V0 (applied to the second atom of the unit cell).
B0s = np.arange(0, 11, 1)
V0s = np.arange(0, 21, 1)

# Result containers, indexed as [iB0, iV0, band].
result_shape = (len(B0s), len(V0s), nband)
cherns = np.zeros(result_shape)  # Chern number of each band
flatness_ratios_all = np.zeros(result_shape)  # flatness ratio of each band

# Main sweep: for every (B0, V0) pair build the tight-binding model, diagonalize
# it on a square_step x square_step grid covering the Brillouin zone, and store
# the Chern number and the flatness ratio of every band in cherns /
# flatness_ratios_all (indexed [iB0, iV0, band]).
for iB0, B0 in enumerate(B0s):
    for iV0, V0 in enumerate(V0s):
        # Undo the rescaling to obtain the actual magnetic field and on-site potential.
        B = B0/d**3/k0**3
        V = V0/d**3/k0**3

        my_model = tb_model(2, 2, lat, orb)

        # On-site energies (in the circular basis): the two orbitals of each
        # atom are shifted by -B and +B; the second atom is additionally offset by V.
        onsite0 = [0, V]
        onsite = []
        for io in range(len(orb0)):
            onsite.append(-B + onsite0[io])
            onsite.append(B + onsite0[io])
        my_model.set_onsite(onsite)

        # Dipolar hoppings towards cells with 0 <= ix <= extentX, |iy| <= extentY.
        # Only one member of each conjugate pair is specified (cells with ix == 0,
        # iy < 0 and home-cell pairs with io2 <= io1 are skipped); this presumes
        # pythtb's set_hop supplies the Hermitian-conjugate term itself.
        for ix in range(0, extentX+1):
            for iy in range(-extentY, extentY+1):
                for io1, o1 in enumerate(orb0):
                    for io2, o2 in enumerate(orb0):
                        if not ((iy > 0) or (ix > 0) or ((ix == 0) and (iy == 0) and (io2 > io1))):
                            continue
                        # Separation vector from atom io1 (home cell) to atom io2 in cell (ix, iy).
                        rvec1 = o1[0]*np.array(lat[0]) + o1[1]*np.array(lat[1])
                        rvec2 = (o2[0]+ix)*np.array(lat[0]) + (o2[1]+iy)*np.array(lat[1])
                        rvec = rvec2 - rvec1
                        phi = np.arctan2(rvec[1], rvec[0])
                        r = np.sqrt(rvec[0]**2 + rvec[1]**2)
                        # Net prefactor: -3/(8*k0**3*r**3).
                        pref = -1./(8*np.pi*r**3)/(k0**2)
                        pref *= 3*np.pi/k0
                        # Couplings that keep the circular component (0->0, 1->1).
                        my_model.set_hop(pref, 2*io1+0, 2*io2+0, [ix, iy])
                        my_model.set_hop(pref, 2*io1+1, 2*io2+1, [ix, iy])
                        # Couplings that switch the circular component, weighted by exp(+-2i*phi).
                        my_model.set_hop(3*np.exp(2j*phi)*pref, 2*io1+0, 2*io2+1, [ix, iy])
                        my_model.set_hop(3*np.exp(-2j*phi)*pref, 2*io1+1, 2*io2+0, [ix, iy])

        # Sample the whole Brillouin zone: reduced k-points on a square grid over
        # [0, square_length] x [0, square_length], both endpoints included.
        # Eigenvectors go into a 2D wf_array (for the Berry flux), eigenvalues
        # into evals (for the flatness ratios).
        square_length = 1
        w_square = wf_array(my_model, [square_step, square_step])
        evals = np.zeros((square_step, square_step, nband))
        for i in range(square_step):
            for j in range(square_step):
                # float() keeps the division exact under Python 2 as well.
                kpt = np.array([square_length*(float(i)/float(square_step-1)),
                                square_length*(float(j)/float(square_step-1))])
                # Named eigvals/eigvecs so the builtin eval() is not shadowed.
                eigvals, eigvecs = my_model.solve_one(kpt, eig_vectors=True)
                w_square[i, j] = eigvecs
                evals[i, j, :] = eigvals

        # Chern number of each band = (Berry flux over the full BZ) / (2*pi).
        print("B0=", B0, "V0=", V0)
        print("Chern numbers:")
        for i in range(nband):
            C = w_square.berry_flux([i])/(2*np.pi)
            print("  for band ", i, " equals    : ", C)
            cherns[iB0, iV0, i] = C
            # Flatness ratio = gap / bandwidth. The smaller of the two ratios
            # (gap to band i+1 and gap to band i-1) is kept; the outermost bands
            # have only one neighbour. Bands are taken in the order solve_one
            # returns them (presumably ascending energy).
            band_max = np.max(evals[:, :, i])
            band_min = np.min(evals[:, :, i])
            bandwidth = band_max - band_min
            flatness_ratios = []
            if i < nband-1:
                flatness_ratios.append((np.min(evals[:, :, i+1]) - band_max)/bandwidth)
            if i > 0:
                flatness_ratios.append((band_min - np.max(evals[:, :, i-1]))/bandwidth)
            flatness_ratios_all[iB0, iV0, i] = np.min(flatness_ratios)

# Plot results (one PDF file per band): Chern number and flatness ratio as
# functions of V0 (horizontal axis) and B0 (vertical axis).


def _half_step(values):
    """Return half the spacing of the evenly spaced 1-D sweep ``values``.

    Used to put the imshow pixel edges halfway between neighbouring sweep
    values. A single-value sweep has no spacing, so 0.5 is returned to still
    give the lone pixel a finite width. Multiplying by 0.5 (instead of
    dividing by 2) keeps the result a float under Python 2, where an integer
    step divided by 2 would floor to 0.
    """
    if len(values) < 2:
        return 0.5
    return 0.5*(values[1] - values[0])


hV0 = _half_step(V0s)
hB0 = _half_step(B0s)
extent = [V0s[0]-hV0, V0s[-1]+hV0, B0s[0]-hB0, B0s[-1]+hB0]
for i in range(len(orb)):
    fig, ax = plt.subplots(1, 2, figsize=(9, 4))
    # For better visibility, the Chern number colour scale is limited to [-2, 2].
    im1 = ax[0].imshow(cherns[:, :, i], origin="lower", extent=extent, vmin=-2, vmax=2)
    im2 = ax[1].imshow(flatness_ratios_all[:, :, i], origin="lower", extent=extent)
    for axis, image, label in ((ax[0], im1, "Chern number"),
                               (ax[1], im2, "Flatness ratio")):
        axis.set_xlabel(r"$V_0$")
        axis.set_ylabel(r"$B_0$")
        axis.set_aspect("auto")
        cb = plt.colorbar(image, ax=axis)
        cb.set_label(label)
    plt.tight_layout()
    plt.savefig("square_band{}.pdf".format(i))
plt.show()
