# Preliminary

# Limit OpenBLAS and MKL to one thread each. This is optional, but recommended
# when running with Parallel(). It is done before numpy is imported.
import os
os.environ.update(OPENBLAS_NUM_THREADS='1', MKL_NUM_THREADS='1')

import wannierberri_OSD.wannierberri as wberri
from wannierberri_OSD.wannierberri import calculators as calc
print(f"Using WannierBerri version {wberri.__version__}")
import numpy as np
import scipy
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from termcolor import cprint
from matplotlib.cm import ScalarMappable
from matplotlib.lines import Line2D

# Parallel backend handed to wberri.run. To change it, replace the active line
# with one of these alternatives:
#   parallel = wberri.Parallel(num_cpus=32)   # fixed number of CPUs
#   parallel = wberri.Serial()                # no parallelism
parallel = wberri.Parallel()  # number of CPUs detected automatically

# Ab initio k-grid settings. Kmin, Kmax and hK give the lower bound, upper bound
# and step of a convergence scan over the pw k-point grid. In this script only
# Kmax is read later, as the fixed FFT grid size.
Kmin, Kmax = 8, 8   # lower / upper bound of the k-grid scan
hK = 1              # step of the k-grid scan
nbnd = 18           # presumably the number of bands of the pw run; not read later here

# Wannier90 data written by pw2wannier90 under this seedname.
seed_pw2wan = 'Se'

# Build the Wannier tight-binding system from those files.
systems = wberri.System_w90(
    seedname=seed_pw2wan,
    OSD=True,
    use_wcc_phase=True,
    use_ws=True,
    transl_inv=False,
    transl_inv_JM=True,
    berry=True,
    wcc_phase_fin_diff=False,
)

# Symmetrize system 

# Sites handed to systems.symmetrize: three Se atoms followed by six 'fict'
# sites. The 'fict' sites are presumably fictitious centres carrying the
# 'fict:s' projections. Coordinates are presumably reduced (fractional);
# confirm against the pw input. 1 / 3 and 2 / 3 are exactly the doubles that
# long decimal expansions of one third and two thirds round to.
pos = np.array([
    # Se
    [ 0.217,  0.217,  0.0      ],
    [-0.217,  0.000,  1 / 3    ],
    [ 0.000, -0.217,  2 / 3    ],
    # fict
    [-0.409, -0.438,  0.2149101],
    [ 0.438,  0.029,  0.5482434],
    [-0.029,  0.409,  0.8815767],
    [-0.438, -0.409, -0.2149101],
    [ 0.029,  0.438,  0.4517565],
    [ 0.409, -0.029,  0.1184232],
])

# Species label per row of pos, in the same order.
at = ['Se'] * 3 + ['fict'] * 6

# Projections used to build the Wannier functions.
proj_wf = ['Se:s', 'Se:p', 'fict:s']

# Symmetrize the Wannier system from the sites (pos), species (at) and
# projections (proj_wf). No spin-orbit coupling, no magnetic moments,
# DFT_code='qe'.
systems.symmetrize(
    positions=pos,
    atom_name=at,
    proj=proj_wf,
    soc=False,
    magmom=None,
    DFT_code='qe',
)

# Alternative to the call above: impose an explicit list of symmetry operations.
# systems.set_symmetry(["C3z", "C2z", "Mx", "TimeReversal"])

# Interpolation grids: NK_min, NK_max and NK_h are the bounds and step of the
# scan over interpolation k-grids. The FFT grid (val_Kgrid) stays fixed at the
# ab initio size Kmax.
NK_min, NK_max, NK_h = 50, 50, 50
val_Kgrid = Kmax
    
# Registry of calculators, filled in below, and the spectral / thermal
# settings they all share.
calculators = {}

T = 10                              # temperature, K
kBT = 8.617333262e-5 * T            # Boltzmann constant (eV/K) times T, in eV
eta = 0.025                         # smearing width, eV
Efermi = np.array([7.7362])         # Fermi level(s), eV
omega = np.linspace(0, 1.5, 1000)   # photon energies, eV

# Save the frequency axis next to the spectra.
np.savetxt('omega.txt', omega)

# Eight asymmetric-SDCT calculators with identical spectral settings. They
# differ only in which contributions are enabled (M1 / E2 / V terms) and in the
# 'external_terms' flag passed through kwargs_formula.
#   key                  M1     E2     V      external_terms
_sdct_specs = (
    ("SDCT_asym_all",    True,  True,  True,  True),
    ("SDCT_asym_M1_all", True,  False, False, True),
    ("SDCT_asym_E2_all", False, True,  False, True),
    ("SDCT_asym_V_all",  False, False, True,  True),
    ("SDCT_asym_int",    True,  True,  True,  False),
    ("SDCT_asym_M1_int", True,  False, False, False),
    ("SDCT_asym_E2_int", False, True,  False, False),
    ("SDCT_asym_V_int",  False, False, True,  False),
)
_sdct_common = dict(Efermi=Efermi, omega=omega, kBT=kBT, smr_type='Gaussian',
                    smr_fixed_width=eta, fermi_sea=True, fermi_surf=False)
for _key, _m1, _e2, _v, _ext in _sdct_specs:
    calculators[_key] = wberri.calculators.dynamic.SDCT_asym(
        **_sdct_common,
        M1_terms=_m1,
        E2_terms=_e2,
        V_terms=_v,
        kwargs_formula={'external_terms': _ext},  # a fresh dict per calculator
    )
# Do not leave the construction helpers behind as script-level names.
del _sdct_specs, _sdct_common, _key, _m1, _e2, _v, _ext

# Run every calculator on each interpolation grid, convert the asymmetric SDCT
# to rotatory power and write one two-column file (omega, value) per tensor
# component and per calculator.

from scipy.constants import elementary_charge, hbar
import itertools

# Calculator key -> (contribution index, external-terms index) in data_in_rho.
#   contribution index:   0 = M1+E2+V, 1 = M1 only, 2 = E2 only, 3 = V only
#   external-terms index: 0 = external_terms=True ("_all" keys),
#                         1 = external_terms=False ("_int" keys)
sdct_slots = {
    "SDCT_asym_all":    (0, 0),
    "SDCT_asym_M1_all": (1, 0),
    "SDCT_asym_E2_all": (2, 0),
    "SDCT_asym_V_all":  (3, 0),
    "SDCT_asym_int":    (0, 1),
    "SDCT_asym_M1_int": (1, 1),
    "SDCT_asym_E2_int": (2, 1),
    "SDCT_asym_V_int":  (3, 1),
}
rp_prefixes = ("RP_k", "RP_k_M1_", "RP_k_E2_", "RP_k_V_")  # by contribution index
rp_suffixes = ("_all.txt", "_int.txt")                      # by external-terms index

# "NK_max + 1" makes the upper bound inclusive without overshooting. The former
# "NK_max + NK_h" produced a grid above NK_max whenever NK_max - NK_min was not
# a multiple of NK_h. Assumes NK_h > 0.
for val_NK in range(NK_min, NK_max + 1, NK_h):
    data_in_rho = np.zeros((len(omega), 4, 2, 3, 3, 3), dtype=float)
    grids_in = wberri.Grid(systems,
                           NK=[val_NK, val_NK, val_NK],
                           NKFFT=[val_Kgrid, val_Kgrid, val_Kgrid]
                           )
    # NOTE(review): fout_name does not depend on val_NK. If wberri.run writes
    # its own output files under this name, a scan over several NK values would
    # overwrite earlier grids' files. Confirm before widening the NK range.
    results_grid_in = wberri.run(systems,
                                 grid=grids_in,
                                 calculators=calculators,
                                 parallel=parallel,
                                 print_Kpoints=False,
                                 adpt_num_iter=0,
                                 fout_name='Se',
                                 restart=False
                                 )
    for key, (term, ext) in sdct_slots.items():
        # data[0] presumably selects the single Fermi level in Efermi; confirm.
        data_in_rho[:, term, ext] = np.real(results_grid_in.results[key].data[0])

    data_in_rho *= (hbar / elementary_charge**2)             # Dimensionless SDCT
    data_in_rho *= (1.0 / 137.035999084)                     # Fine structure constant
    data_in_rho *= (2.0 * np.pi / 299792458)                 # 2pi/c
    # Photon energy (eV) divided by hbar in eV*s gives a frequency in s^-1.
    data_in_rho *= omega[:, None, None, None, None, None] / 6.582119569e-16
    data_in_rho *= (360 / (2.0 * np.pi))                     # Radians to degrees
    data_in_rho *= (1.0 / 1000.0)                            # Per meter to per millimeter

    for ia, ib, ic in itertools.product(range(3), repeat=3):
        for ext, suffix in enumerate(rp_suffixes):
            for term, prefix in enumerate(rp_prefixes):
                filename = f"{prefix}{val_NK}_{ia}{ib}{ic}{suffix}"
                np.savetxt(filename,
                           np.column_stack((omega, data_in_rho[:, term, ext, ia, ib, ic])))
print(" ---------------------------------------------------------------------- ")
