# Preliminary

# Limit OpenBLAS and MKL to a single thread each. Not mandatory, but
# recommended when the run is distributed with Parallel(). The variables are
# set here, ahead of the numerical imports that follow.
import os
os.environ.update({'OPENBLAS_NUM_THREADS': '1', 'MKL_NUM_THREADS': '1'})

import wannierberri_OSD.wannierberri as wberri
from wannierberri_OSD.wannierberri import calculators as calc
print (f"Using WannierBerri version {wberri.__version__}")
import numpy as np
import scipy
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from termcolor import cprint
from matplotlib.cm import ScalarMappable
from matplotlib.lines import Line2D

# Execution backend; it is handed to wberri.run() further down.
# Choose one of the options (keep exactly one assignment uncommented):

#parallel = wberri.Parallel(num_cpus=32)  # explicit number of CPUs
parallel = wberri.Parallel()  # Automatic detection
#parallel = wberri.Serial()  # presumably runs without parallelization - confirm in the WannierBerri docs

# Input data
# Lower bound, upper bound and step (delta k) of the loop over k-point grids
# used to study convergence of the pw calculation.
Kmin, Kmax, hK = 6, 6, 1
nbnd = 16  # NOTE(review): presumably the number of bands; not used in the visible code

# Importing data from wannier90 (data obtained from pw2wan)

seed_pw2wan = 'GaN'

# Constructor switches, gathered in one place and forwarded unchanged.
# NOTE(review): the meaning of each flag is defined by the wannierberri_OSD
# fork of System_w90 - check its documentation before changing any of them.
system_options = dict(
    OSD=True,
    use_wcc_phase=True,
    use_ws=True,
    transl_inv=False,
    transl_inv_JM=True,
    berry=True,
    wcc_phase_fin_diff=False,
)

systems = wberri.System_w90(seedname=seed_pw2wan, **system_options)

# Symmetrize system

# Atom labels (NOTE(review): presumably in the same order as the rows of
# `pos` - four labels, four rows; confirm)
at = ['Ga', 'Ga', 'N', 'N']

# Projections, one entry per species
proj_wf = ['Ga:s;p', 'N:s;p']

# One position per atom
# (NOTE(review): presumably fractional/reduced coordinates - confirm)
pos = np.array([
    [0.333333333, 0.666666667, 0.9991755892],
    [0.666666667, 0.333333333, 0.4991755892],
    [0.333333333, 0.666666667, 0.3758244798],
    [0.666666667, 0.333333333, 0.8758244798],
])

symmetrize_options = dict(
    positions=pos,
    atom_name=at,
    proj=proj_wf,
    soc=False,
    magmom=None,
    DFT_code='qe',
)
systems.symmetrize(**symmetrize_options)

# Symmetry operations registered on the system after symmetrization
systems.set_symmetry(["C3z", "C2z", "Mx", "TimeReversal"])

# Fixed ab initio grid, various interpolation grids
NK_min, NK_max, NK_h = 300, 300, 20   # loop bounds and step of the interpolation grid size NK
val_Kgrid = Kmax                      # NKFFT is built from the k grid of the pw calculation


def _grid_sizes(n_xy):
    """Return [n, n, round(n*4/6)]: the third direction gets 4/6 of the divisions of the first two."""
    return [n_xy, n_xy, round(n_xy*4./6.)]


# One Grid per NK value, keyed by that NK value; NKFFT is the same for all of them
grids_in = {
    val_NK: wberri.Grid(systems,
                        NK=_grid_sizes(val_NK),
                        NKFFT=_grid_sizes(val_Kgrid))
    for val_NK in range(NK_min, NK_max+NK_h, NK_h)
}
    
# Defining the calculators
T = 10                            # K
kBT = 8.617333262e-5 * T          # eV
eta = 0.04                        # eV, passed as the fixed smearing width
Efermi = np.array([11.0758])      # eV
omega = np.linspace(0, 3.0, 201)  # eV

# Store the frequency grid on disk
np.savetxt('omega.txt', omega)

# (M1_terms, E2_terms, V_terms) switches for every calculator: the first one
# enables all three contributions, the others enable exactly one each.
term_switches = {
    "SDCT_asym_all":    (True,  True,  True),
    "SDCT_asym_M1_all": (True,  False, False),
    "SDCT_asym_E2_all": (False, True,  False),
    "SDCT_asym_V_all":  (False, False, True),
}

# All calculators share the same energies, smearing and Fermi-sea settings;
# only the term switches differ.
calculators = {
    calc_name: wberri.calculators.dynamic.SDCT_asym(
        Efermi=Efermi,
        omega=omega,
        kBT=kBT,
        smr_type='Gaussian',
        smr_fixed_width=eta,
        fermi_sea=True,
        fermi_surf=False,
        M1_terms=use_M1,
        E2_terms=use_E2,
        V_terms=use_V,
        kwargs_formula={'external_terms': True},  # a fresh dict for each calculator
    )
    for calc_name, (use_M1, use_E2, use_V) in term_switches.items()
}

# Running the calculators for the grids_in
# Options shared by every run.
# NOTE(review): fout_name is identical for all grids - if more than one NK
# value is looped over, confirm that one run's output does not overwrite
# another's.
run_options = dict(
    calculators=calculators,
    parallel=parallel,
    print_Kpoints=False,
    adpt_num_iter=0,
    fout_name='GaN',
    restart=False,
)

# One result object per interpolation grid, keyed by its NK value
results_grid_in = {}
for val_NK, grid_in in grids_in.items():
    results_grid_in[val_NK] = wberri.run(systems, grid=grid_in, **run_options)
print(" ---------------------------------------------------------------------- ")

# Cartesian components

from scipy.constants import elementary_charge, hbar

# Calculator stored in each slot of the third array axis, paired with the tag
# that its output file names carry right after "_k".
term_slots = (
    ("SDCT_asym_all", ""),
    ("SDCT_asym_M1_all", "_M1_"),
    ("SDCT_asym_E2_all", "_E2_"),
    ("SDCT_asym_V_all", "_V_"),
)

# Unit-conversion factors. They are applied one after another, in this order,
# so the floating-point result does not depend on how they are grouped.
conversion_factors = (
    hbar / elementary_charge**2,                          # Dimensionless SDCT
    1.0 / 137.035999084,                                  # Fine structure constant
    2.0 * np.pi / 299792458,                              # 2pi/c
    omega[:,None,None,None,None,None] / 6.582119569e-16,  # Frequency in s⁻¹
    360 / (2.0 * np.pi),                                  # Radians to degrees
    1.0/1000.0,                                           # Per meter to per millimeter
)

# Axes: (NK value, frequency, term slot, 2, a, b, c). Only index 0 of the
# size-2 axis is ever filled here.
# NOTE(review): the first axis is indexed by the NK value itself, so
# NK_max+NK_h slabs are allocated although only the looped NK values are used.
data_shape = (NK_max+NK_h, len(omega), 4, 2, 3, 3, 3)
data_in_rho_imag = np.zeros(data_shape, dtype=float)
data_in_rho_real = np.zeros(data_shape, dtype=float)

# File-name prefix, destination array and complex part taken for each output
# family: "CD" files hold the imaginary part, "RP" files the real part
# (presumably circular dichroism and rotatory power - confirm).
output_families = (
    ("CD", data_in_rho_imag, np.imag),
    ("RP", data_in_rho_real, np.real),
)

for val_NK in range(NK_min, NK_max+NK_h, NK_h):
    grid_results = results_grid_in[val_NK].results
    for _, data_store, take_part in output_families:
        slab = data_store[val_NK]  # view: in-place operations update data_store
        for slot, (calc_name, _) in enumerate(term_slots):
            # .data[0] must broadcast to (len(omega), 3, 3, 3)
            slab[:, slot, 0] = take_part(grid_results[calc_name].data[0])
        for factor in conversion_factors:
            slab *= factor
    # One two-column text file (omega, value) per Cartesian component,
    # output family and term slot
    for ia, ib, ic in np.ndindex(3, 3, 3):
        for file_kind, data_store, _ in output_families:
            for slot, (_, file_tag) in enumerate(term_slots):
                np.savetxt(f"{file_kind}_k{file_tag}{val_NK}_{ia}{ib}{ic}_all.txt",
                           np.column_stack((omega, data_store[val_NK][:, slot, 0, ia, ib, ic])))
