#!/usr/bin/env python
# coding: utf-8

""" 
Functions for generating AFQMC trial wave functions, Hamiltonians, and QMCPACK inputs.
Hung Q. Pham (hung.pham@bytedance.com)
"""

import os, sys
import re
import subprocess

import numpy as np
from pyscf import fci, lib
from pyscf.tools import fcidump
from mrh.my_pyscf.mcscf import addons as las_addons
from afqmctools.utils.pyscf_utils import load_from_pyscf_chk_mol
from afqmctools.hamiltonian.mol import write_hamil_mol
from afqmctools.wavefunction.mol import write_qmcpack_wfn, write_wfn_mol
from afqmctools.utils.qmcpack_utils import write_xml_input
from afqmctools.utils.linalg import get_ortho_ao_mol

import h5py
from pyscf.lib.chkfile import load_chkfile_key, load
from pyscf.lib.chkfile import dump_chkfile_key, dump, save
from pyscf.lib.chkfile import load_mol, save_mol


# Root of the QMCPACK installation (site-specific; edit for your machine).
# gen_uhf_trial / gen_rhf_trial run utils/afqmctools/bin/pyscf_to_afqmc.py
# relative to this directory.
qmcpack_directory = '/opt/tiger/quchem/qmcpack-3.14.0'


def gen_las_trial(filename, mol, mc, tol=0.005, verbose=False, chol=1e-5, ndets=None, cas=None):
    '''Write an AFQMC Hamiltonian and a LASSCF multi-determinant trial.

    The LASSCF wave function is converted to a CAS CI vector with
    ``las_addons.las2cas_civec`` and the first root is expanded in Slater
    determinants. Both the Hamiltonian and the trial are written to
    ``qmc_<filename>.h5``.

    Args:
        filename: base name; the output file is ``'qmc_' + filename + '.h5'``.
        mol: PySCF molecule; ``mol.nelec`` gives the total (alpha, beta) count.
        mc: LASSCF object; ``ncas``, ``ncore``, ``mo_coeff``, ``mo_occ`` and
            ``chkfile`` are read.
        tol: keep determinants with ``abs(ci) > tol``.
        verbose: forwarded to ``write_hamil_mol``.
        chol: Cholesky threshold forwarded to ``write_hamil_mol``.
        ndets: if an integer (Python or numpy), keep only the ``ndets``
            determinants with the largest ``abs(ci)``.
        cas: optional active-space pair forwarded to ``write_hamil_mol``;
            ``cas[1]`` is used as the number of orbitals kept, and the
            ``nmo - cas[1]`` lowest orbitals are treated as frozen.
    '''
    filename = 'qmc_' + filename
    nmo = mc.mo_coeff.shape[-1]
    # NOTE(review): ci_cas / nelec_cas look like per-root containers; the first
    # root is used for both so the CI vector and electron counts match.
    ci_cas, nelec_cas = las_addons.las2cas_civec(mc)

    # Determinants as (coeff, occ_a, occ_b); only those with abs(ci) > tol.
    dets = fci.addons.large_ci(ci_cas[0], mc.ncas, nelec_cas[0], tol=tol, return_strs=False)
    ci, occa, occb = zip(*dets)
    ci = np.array(ci)
    occa = np.array(occa)
    occb = np.array(occb)

    # Sort by decreasing |ci|, optionally keeping only the leading ndets.
    ixs = np.argsort(np.abs(ci))[::-1]
    if isinstance(ndets, (int, np.integer)):
        ixs = ixs[:ndets]
    ci, occa, occb = ci[ixs], occa[ixs], occb[ixs]

    # Prepend the core orbitals: the AFQMC run is not restricted to the active
    # space. With `cas`, the frozen orbitals are removed from core and electrons.
    if cas:
        nfrozen = nmo - cas[1]
        ncore = mc.ncore - nfrozen
        nelec = (mol.nelec[0] - nfrozen, mol.nelec[1] - nfrozen)
        nmo = cas[1]
    else:
        ncore = mc.ncore
        nelec = mol.nelec
    core = list(range(ncore))
    occa = [np.array(core + [o + ncore for o in oa]) for oa in occa]
    occb = [np.array(core + [o + ncore for o in ob]) for ob in occb]

    # Hamiltonian + LASSCF trial into the same h5 file.
    # NOTE(review): the original author noted LASSCF objects may not carry a
    # chkfile; mc.chkfile must point to a file with an 'mcscf' group (e.g. one
    # written by dump_scf) — confirm with the caller.
    scf_data = load_from_pyscf_chk_mol(mc.chkfile, 'mcscf')
    scf_data['mo_coeff'] = mc.mo_coeff
    scf_data['mo_occ'] = mc.mo_occ
    write_hamil_mol(scf_data, filename + ".h5", chol, verbose=verbose, real_chol=False, cas=cas)
    ci = np.array(ci, dtype=np.complex128)
    uhf = True  # UHF always true for CI expansions.
    write_qmcpack_wfn(filename + ".h5", (ci, occa, occb), uhf, nelec, nmo)
    
def gen_cas_trial(filename, mol, mc, tol=0.005, verbose=False, chol=1.e-5, ndets=None, cas=None):
    '''Write an AFQMC Hamiltonian and a CASSCF/CASCI multi-determinant trial.

    Determinants of ``mc.ci`` with ``abs(coeff) > tol`` are ordered by
    decreasing weight, optionally truncated, padded with the core orbitals,
    and written with the Hamiltonian to ``qmc_<filename>.h5``.

    Args:
        filename: base name; the output file is ``'qmc_' + filename + '.h5'``.
        mol: PySCF molecule; ``mol.nelec`` gives the total (alpha, beta) count.
        mc: CASSCF/CASCI object (``ci``, ``ncas``, ``ncore``, ``nelecas``,
            ``mo_coeff``, ``mo_occ``, ``chkfile`` are read).
        tol: coefficient threshold passed to ``large_ci``.
        verbose: forwarded to ``write_hamil_mol``.
        chol: Cholesky threshold forwarded to ``write_hamil_mol``.
        ndets: if an ``int``, number of leading determinants to keep.
        cas: optional active-space pair forwarded to ``write_hamil_mol``;
            ``cas[1]`` orbitals are kept and ``nmo - cas[1]`` are frozen.
    '''
    prefix = 'qmc_' + filename
    neleca, nelecb = mc.nelecas
    norb = mc.mo_coeff.shape[-1]

    # large_ci yields (coeff, occ_a, occ_b) tuples; transpose into three arrays.
    expansion = fci.addons.large_ci(mc.ci, mc.ncas, (neleca, nelecb), tol=tol, return_strs=False)
    coeffs, dets_a, dets_b = (np.array(column) for column in zip(*expansion))

    # Largest |coefficient| first; an int ndets keeps only that many.
    order = np.argsort(np.abs(coeffs))[::-1]
    if isinstance(ndets, int):
        order = order[:ndets]
    coeffs, dets_a, dets_b = coeffs[order], dets_a[order], dets_b[order]

    # Shift active-space occupations past the (non-frozen) core orbitals.
    if not cas:
        n_core = mc.ncore
        nelec = mol.nelec
    else:
        n_frozen = norb - cas[1]
        n_core = mc.ncore - n_frozen
        nelec = (mol.nelec[0] - n_frozen, mol.nelec[1] - n_frozen)
        norb = cas[1]
    core = list(range(n_core))
    dets_a = [np.array(core + [orb + n_core for orb in occ]) for occ in dets_a]
    dets_b = [np.array(core + [orb + n_core for orb in occ]) for occ in dets_b]

    h5_file = prefix + ".h5"
    scf_data = load_from_pyscf_chk_mol(mc.chkfile, 'mcscf')
    scf_data['mo_coeff'] = mc.mo_coeff
    scf_data['mo_occ'] = mc.mo_occ
    write_hamil_mol(scf_data, h5_file, chol, verbose=verbose, real_chol=False, cas=cas)
    # CI expansions are always written in UHF form with complex coefficients.
    write_qmcpack_wfn(h5_file, (coeffs.astype(np.complex128), dets_a, dets_b), True, nelec, norb)
    
def gen_input(filename, blocks=10000, nWalkers=20, rediag=False, timestep=0.005, steps=5, substeps=5):
    '''Write the QMCPACK AFQMC input ``qmc_<filename>.xml``.

    The XML points at ``qmc_<filename>.h5`` for both the Hamiltonian and the
    wave function (the file written by e.g. gen_cas_trial / gen_las_trial).

    Args:
        filename: base name; ``'qmc_'`` is prepended.
        blocks: number of AFQMC blocks.
        nWalkers: number of walkers.
        rediag: value of the Wavefunction ``rediag`` option.
        timestep, steps, substeps: execute-section parameters (defaults match
            the previously hard-coded values).
    '''
    filename = 'qmc_' + filename
    options = {
        "Wavefunction": {
            "rediag": rediag,
        },
        "execute": {
            "timestep": timestep,
            "blocks": blocks,
            "steps": steps,
            "substeps": substeps,
            "nWalkers": nWalkers,
        },
    }
    xml_file = filename + ".xml"
    write_xml_input(xml_file, filename + ".h5", filename + ".h5",
                    options=options, rng_seed=7)

    # Replace every quoted "qmc" in the generated XML with the file prefix.
    # Done in Python (not `sed` via the shell) so prefixes containing '/' or
    # '&' are substituted literally.
    with open(xml_file, encoding="utf-8") as fxml:
        xml = fxml.read()
    with open(xml_file, "w", encoding="utf-8") as fxml:
        fxml.write(xml.replace('"qmc"', '"{0}"'.format(filename)))
    
def gen_uhf_trial(filename, chkfile, blocks=10000, nWalkers=20, chol=1e-5):
    '''Generate an AFQMC Hamiltonian, UHF trial and QMCPACK input with pyscf_to_afqmc.py.

    Runs ``pyscf_to_afqmc.py`` (from ``qmcpack_directory``) on ``chkfile`` with
    ``-a``, then patches ``qmc_<filename>.xml`` with sed: steps 10 -> 5, a
    ``substeps`` parameter of 5, the requested nWalkers and blocks, and every
    quoted "qmc" replaced by the file prefix.

    Args:
        filename: base name; outputs are ``qmc_<filename>.h5`` / ``.xml``.
        chkfile: PySCF checkpoint file passed to ``-i``.
        blocks: replaces the default ``blocks`` value 10000 in the XML.
        nWalkers: replaces the default ``nWalkers`` value 10 in the XML.
        chol: Cholesky threshold passed to ``-t``.

    Raises:
        subprocess.CalledProcessError: if the converter or any sed call fails.
    '''
    filename = 'qmc_' + filename
    xml_file = filename + '.xml'
    converter = '{0}/utils/afqmctools/bin/pyscf_to_afqmc.py'.format(qmcpack_directory)
    # List arguments: no shell parsing. The previous f-string + .format mix
    # evaluated {0}/{1}/{2} as literal integers and ran the wrong command.
    subprocess.run([converter, '-i', chkfile, '-a', '-t', str(chol),
                    '-q', xml_file, '-o', filename + '.h5'], check=True)

    sed_edits = [
        's/"steps">10/"steps">5/g',
        r'/"steps">*/a \ \ \ \ <parameter name="substeps">5</parameter>',
        's/"nWalkers">10/"nWalkers">{0}/g'.format(nWalkers),
        's/"blocks">10000/"blocks">{0}/g'.format(blocks),
        's/"qmc"/"{0}"/g'.format(filename),
    ]
    for expr in sed_edits:
        subprocess.run(['sed', '-i', '-e', expr, xml_file], check=True)
    
def gen_rhf_trial(filename, chkfile, blocks=10000, nWalkers=20, chol=1e-5, cas=None):
    '''Generate an AFQMC Hamiltonian, RHF trial and QMCPACK input with pyscf_to_afqmc.py.

    Runs ``pyscf_to_afqmc.py`` (from ``qmcpack_directory``) on ``chkfile``,
    then patches ``qmc_<filename>.xml`` with sed: steps 10 -> 5, a
    ``substeps`` parameter of 5, the requested nWalkers and blocks, and every
    quoted "qmc" replaced by the file prefix.

    Args:
        filename: base name; outputs are ``qmc_<filename>.h5`` / ``.xml``.
        chkfile: PySCF checkpoint file passed to ``-i``.
        blocks: replaces the default ``blocks`` value 10000 in the XML.
        nWalkers: replaces the default ``nWalkers`` value 10 in the XML.
        chol: Cholesky threshold passed to ``-t``.
        cas: optional pair passed as ``-c cas[0],cas[1]``; when None the
            ``-c`` option is omitted (previously this crashed with TypeError).

    Raises:
        subprocess.CalledProcessError: if the converter or any sed call fails.
    '''
    filename = 'qmc_' + filename
    xml_file = filename + '.xml'
    converter = '{0}/utils/afqmctools/bin/pyscf_to_afqmc.py'.format(qmcpack_directory)
    # List arguments: no shell parsing. The previous f-string + .format mix
    # evaluated {0}..{4} as literal integers and ran the wrong command.
    cmd = [converter, '-i', chkfile, '-t', str(chol),
           '-q', xml_file, '-o', filename + '.h5']
    if cas is not None:
        cmd += ['-c', '{0},{1}'.format(cas[0], cas[1])]
    subprocess.run(cmd, check=True)

    sed_edits = [
        's/"steps">10/"steps">5/g',
        r'/"steps">*/a \ \ \ \ <parameter name="substeps">5</parameter>',
        's/"nWalkers">10/"nWalkers">{0}/g'.format(nWalkers),
        's/"blocks">10000/"blocks">{0}/g'.format(blocks),
        's/"qmc"/"{0}"/g'.format(filename),
    ]
    for expr in sed_edits:
        subprocess.run(['sed', '-i', '-e', expr, xml_file], check=True)
    
    
def gen_gamma_mf_trial(filename, cell, mf, chkfile, blocks=10000, nWalkers=20, ortho_ao=False, rediag=False, timestep=0.005,
                       steps=5, substeps=5, LINDEP_CUTOFF=0):
    '''Write a Gamma-point Hamiltonian, mean-field trial and QMCPACK input.

    Assumes density-fitted integrals (``mf.with_df``). The 3-index tensor from
    ``mf.with_df.sr_loop`` is stored as ``scf_data['df_ints']`` and the
    Hamiltonian is written with ``df=True``.

    Args:
        filename: base name; outputs are ``qmc_<filename>.h5`` / ``.xml``.
        cell: cell object stored as ``scf_data['mol']``.
        mf: mean-field object (``with_df``, ``get_hcore``, ``get_ovlp``,
            ``mo_coeff``, ``mo_occ`` are used).
        chkfile: read with ``load_from_pyscf_chk_mol(chkfile, 'scf')``.
        blocks, nWalkers, rediag, timestep, steps, substeps: written to the
            XML input (previously timestep/steps/substeps were ignored).
        ortho_ao: forwarded to ``write_hamil_mol`` and ``write_wfn_mol``.
        LINDEP_CUTOFF: forwarded to ``get_ortho_ao_mol`` (previously ignored).
    '''
    assert mf.with_df is not None
    filename = 'qmc_' + filename
    scf_data = load_from_pyscf_chk_mol(chkfile, 'scf')
    scf_data['mo_coeff'] = mf.mo_coeff
    scf_data['mo_occ'] = mf.mo_occ
    scf_data['hcore'] = mf.get_hcore()
    scf_data['mol'] = cell
    # NOTE(review): LINDEP_CUTOFF is presumably the overlap-eigenvalue cutoff
    # used by get_ortho_ao_mol to drop linear dependencies — confirm.
    scf_data['X'] = get_ortho_ao_mol(mf.get_ovlp(), LINDEP_CUTOFF)

    # Only the real block LpqR is kept; LpqI and sign are discarded
    # (presumably fine at the Gamma point where integrals are real).
    scf_data['df_ints'] = np.concatenate(
        [LpqR for LpqR, _LpqI, _sign in mf.with_df.sr_loop(compact=False)])

    write_hamil_mol(scf_data, filename + ".h5", 1e-5, ortho_ao=ortho_ao, verbose=False, real_chol=False, df=True)
    write_wfn_mol(scf_data, ortho_ao, filename + ".h5")

    options = {
        "Wavefunction": {
            "rediag": rediag,
        },
        "execute": {
            "timestep": timestep,
            "blocks": blocks,
            "steps": steps,
            "substeps": substeps,
            "nWalkers": nWalkers,
        },
    }
    xml_file = filename + ".xml"
    write_xml_input(xml_file, filename + ".h5", filename + ".h5",
                    options=options, rng_seed=7)

    # Replace every quoted "qmc" in the generated XML with the file prefix
    # (literal substitution, no shell involved).
    with open(xml_file, encoding="utf-8") as fxml:
        xml = fxml.read()
    with open(xml_file, "w", encoding="utf-8") as fxml:
        fxml.write(xml.replace('"qmc"', '"{0}"'.format(filename)))
    
def gen_gamma_cas_trial(filename, cell, mf, mc, chkfile, tol=0.005, verbose=False,
                        ndets=None, blocks=10000, nWalkers=20, rediag=False, timestep=0.005,
                        steps=5, substeps=5):
    '''Write a Gamma-point Hamiltonian, CASSCF/CASCI trial and QMCPACK input.

    Assumes density-fitted integrals (``mf.with_df``). Determinants of
    ``mc.ci`` with ``abs(coeff) > tol`` are sorted by decreasing weight,
    optionally truncated, padded with ``mc.ncore`` core orbitals and written
    with the Hamiltonian to ``qmc_<filename>.h5``.

    Args:
        filename: base name; outputs are ``qmc_<filename>.h5`` / ``.xml``.
        cell: cell object stored as ``scf_data['mol']``; ``cell.nelec`` is
            the electron count given to the wave-function writer.
        mf: mean-field object providing ``with_df`` and ``get_hcore``.
        mc: CASSCF/CASCI object (``ci``, ``ncas``, ``ncore``, ``nelecas``,
            ``mo_coeff``, ``mo_occ``).
        chkfile: read with ``load_from_pyscf_chk_mol(chkfile, 'scf')``.
        tol: coefficient threshold passed to ``large_ci``.
        verbose: forwarded to ``write_hamil_mol``.
        ndets: if an integer (Python or numpy), number of leading determinants kept.
        blocks, nWalkers, rediag, timestep, steps, substeps: written to the
            XML input (previously timestep/steps/substeps were ignored).
    '''
    assert mf.with_df is not None
    filename = 'qmc_' + filename

    nalpha, nbeta = mc.nelecas
    nmo = mc.mo_coeff.shape[-1]

    # Determinants as (coeff, occ_a, occ_b); only those with abs(ci) > tol.
    ci, occa, occb = zip(*fci.addons.large_ci(mc.ci, mc.ncas, (nalpha, nbeta), tol=tol, return_strs=False))
    ci = np.array(ci)
    occa = np.array(occa)
    occb = np.array(occb)

    # Sort by decreasing |ci|, optionally keeping only the leading ndets.
    ixs = np.argsort(np.abs(ci))[::-1]
    if isinstance(ndets, (int, np.integer)):
        ixs = ixs[:ndets]
    ci, occa, occb = ci[ixs], occa[ixs], occb[ixs]

    # Prepend the core orbitals: AFQMC is not run in the active space only.
    core = list(range(mc.ncore))
    occa = [np.array(core + [o + mc.ncore for o in oa]) for oa in occa]
    occb = [np.array(core + [o + mc.ncore for o in ob]) for ob in occb]

    # Hamiltonian + CAS trial into the same h5 file.
    scf_data = load_from_pyscf_chk_mol(chkfile, 'scf')
    scf_data['mo_coeff'] = mc.mo_coeff
    scf_data['mo_occ'] = mc.mo_occ
    scf_data['hcore'] = mf.get_hcore()
    scf_data['mol'] = cell
    # Only the real block LpqR is kept; LpqI and sign are discarded
    # (presumably fine at the Gamma point where integrals are real).
    scf_data['df_ints'] = np.concatenate(
        [LpqR for LpqR, _LpqI, _sign in mf.with_df.sr_loop(compact=False)])

    write_hamil_mol(scf_data, filename + ".h5", 1e-5, verbose=verbose, real_chol=False, df=True)
    ci = np.array(ci, dtype=np.complex128)
    uhf = True  # UHF always true for CI expansions.
    write_qmcpack_wfn(filename + ".h5", (ci, occa, occb), uhf, cell.nelec, nmo)

    options = {
        "Wavefunction": {
            "rediag": rediag,
        },
        "execute": {
            "timestep": timestep,
            "blocks": blocks,
            "steps": steps,
            "substeps": substeps,
            "nWalkers": nWalkers,
        },
    }
    xml_file = filename + ".xml"
    write_xml_input(xml_file, filename + ".h5", filename + ".h5",
                    options=options, rng_seed=7)

    # Replace every quoted "qmc" in the generated XML with the file prefix
    # (literal substitution, no shell involved).
    with open(xml_file, encoding="utf-8") as fxml:
        xml = fxml.read()
    with open(xml_file, "w", encoding="utf-8") as fxml:
        fxml.write(xml.replace('"qmc"', '"{0}"'.format(filename)))
    
def gen_gamma_las_trial(filename, cell, mf, mc, chkfile, tol=0.005, verbose=False,
                        ndets=None, blocks=10000, nWalkers=20, rediag=False, timestep=0.005,
                        steps=5, substeps=5):
    '''Write a Gamma-point Hamiltonian, LASSCF trial and QMCPACK input.

    Assumes density-fitted integrals (``mf.with_df``). The LASSCF wave
    function is converted with ``las_addons.las2cas_civec`` and the first root
    is expanded in determinants (``abs(coeff) > tol``), sorted by decreasing
    weight, optionally truncated, padded with ``mc.ncore`` core orbitals and
    written with the Hamiltonian to ``qmc_<filename>.h5``.

    Args:
        filename: base name; outputs are ``qmc_<filename>.h5`` / ``.xml``.
        cell: cell object stored as ``scf_data['mol']``; ``cell.nelec`` is
            the electron count given to the wave-function writer.
        mf: mean-field object providing ``with_df`` and ``get_hcore``.
        mc: LASSCF object (``ncas``, ``ncore``, ``mo_coeff``, ``mo_occ``).
        chkfile: read with ``load_from_pyscf_chk_mol(chkfile, 'scf')``.
        tol: coefficient threshold passed to ``large_ci``.
        verbose: forwarded to ``write_hamil_mol``.
        ndets: if an integer (Python or numpy), number of leading determinants kept.
        blocks, nWalkers, rediag, timestep, steps, substeps: written to the
            XML input (previously timestep/steps/substeps were ignored).
    '''
    assert mf.with_df is not None
    filename = 'qmc_' + filename
    nmo = mc.mo_coeff.shape[-1]

    # NOTE(review): ci_cas / nelec_cas look like per-root containers; take the
    # first root for both, as gen_las_trial does. Passing the whole nelec
    # container to large_ci does not give it an (nalpha, nbeta) pair.
    ci_cas, nelec_cas = las_addons.las2cas_civec(mc)

    # Determinants as (coeff, occ_a, occ_b); only those with abs(ci) > tol.
    dets = fci.addons.large_ci(ci_cas[0], mc.ncas, nelec_cas[0], tol=tol, return_strs=False)
    ci, occa, occb = zip(*dets)
    ci = np.array(ci)
    occa = np.array(occa)
    occb = np.array(occb)

    # Sort by decreasing |ci|, optionally keeping only the leading ndets.
    ixs = np.argsort(np.abs(ci))[::-1]
    if isinstance(ndets, (int, np.integer)):
        ixs = ixs[:ndets]
    ci, occa, occb = ci[ixs], occa[ixs], occb[ixs]

    # Prepend the core orbitals: AFQMC is not run in the active space only.
    core = list(range(mc.ncore))
    occa = [np.array(core + [o + mc.ncore for o in oa]) for oa in occa]
    occb = [np.array(core + [o + mc.ncore for o in ob]) for ob in occb]

    # Hamiltonian + LASSCF trial into the same h5 file.
    scf_data = load_from_pyscf_chk_mol(chkfile, 'scf')
    scf_data['mo_coeff'] = mc.mo_coeff
    scf_data['mo_occ'] = mc.mo_occ
    scf_data['hcore'] = mf.get_hcore()
    scf_data['mol'] = cell
    # Only the real block LpqR is kept; LpqI and sign are discarded
    # (presumably fine at the Gamma point where integrals are real).
    scf_data['df_ints'] = np.concatenate(
        [LpqR for LpqR, _LpqI, _sign in mf.with_df.sr_loop(compact=False)])

    write_hamil_mol(scf_data, filename + ".h5", 1e-5, verbose=verbose, real_chol=False, df=True)
    ci = np.array(ci, dtype=np.complex128)
    uhf = True  # UHF always true for CI expansions.
    write_qmcpack_wfn(filename + ".h5", (ci, occa, occb), uhf, cell.nelec, nmo)

    options = {
        "Wavefunction": {
            "rediag": rediag,
        },
        "execute": {
            "timestep": timestep,
            "blocks": blocks,
            "steps": steps,
            "substeps": substeps,
            "nWalkers": nWalkers,
        },
    }
    xml_file = filename + ".xml"
    write_xml_input(xml_file, filename + ".h5", filename + ".h5",
                    options=options, rng_seed=7)

    # Replace every quoted "qmc" in the generated XML with the file prefix
    # (literal substitution, no shell involved).
    with open(xml_file, encoding="utf-8") as fxml:
        xml = fxml.read()
    with open(xml_file, "w", encoding="utf-8") as fxml:
        fxml.write(xml.replace('"qmc"', '"{0}"'.format(filename)))
    
def dump_scf(mol, chkfile, e_tot, mo_energy, mo_coeff, mo_occ,
             overwrite_mol=True):
    '''Save (possibly intermediate) results to ``chkfile``.

    The molecule is written first: with ``overwrite_mol=False`` and an
    existing HDF5 file, a ``'mol'`` record is added only if missing;
    otherwise ``save_mol`` writes it. Then ``e_tot``, ``mo_energy``,
    ``mo_occ`` and ``mo_coeff`` are stored under the ``'mcscf'`` key.
    '''
    keep_existing_mol = h5py.is_hdf5(chkfile) and not overwrite_mol
    if not keep_existing_mol:
        save_mol(mol, chkfile)
    else:
        with h5py.File(chkfile, 'a') as handle:
            if 'mol' not in handle:
                handle['mol'] = mol.dumps()

    results = dict(e_tot=e_tot, mo_energy=mo_energy, mo_occ=mo_occ, mo_coeff=mo_coeff)
    save(chkfile, 'mcscf', results)
       
def _read_fcidump_header(fcidump_file):
    '''Return ``(norb, nelec, ms2)`` parsed from the namelist header of an FCIDUMP file.

    Lines are read up to and including the first one containing ``&END`` or
    consisting of ``/``. ``NORB``/``NELEC``/``MS2`` are matched case-insensitively
    regardless of spacing or commas; ``MS2`` is taken as 0 when absent.

    Raises:
        ValueError: if ``NORB`` or ``NELEC`` is not found in the header.
    '''
    header = []
    with open(fcidump_file, "r") as ffci:
        for line in ffci:
            header.append(line)
            if re.search(r'&END|^\s*/\s*$', line, re.IGNORECASE):
                break
    text = ''.join(header)

    def _field(key, default=None):
        match = re.search(r'\b' + key + r'\s*=\s*(-?\d+)', text, re.IGNORECASE)
        if match is None:
            if default is None:
                raise ValueError('{0} not found in FCIDUMP header of {1}'.format(key, fcidump_file))
            return default
        return int(match.group(1))

    return _field('NORB'), _field('NELEC'), _field('MS2', 0)


def gen_SHCI_trial(filename, chkfile, fcidump_file='FCIDUMP',ndets=None, dice_file='dets.bin', blocks=50000, nWalkers=20, chol=1.e-5, rediag=False, verbose=False ):
    '''Write an AFQMC Hamiltonian, SHCI (Dice) trial and QMCPACK input.

    Orbital and electron counts come from the FCIDUMP header; the CI
    expansion comes from the Dice determinant file. Determinants are sorted by
    decreasing ``abs(ci)``, optionally truncated, and written together with
    the Hamiltonian to ``qmc_<filename>.h5``; the input goes to
    ``qmc_<filename>.xml``.

    Args:
        filename: base name; ``'qmc_'`` is prepended.
        chkfile: read with ``load_from_pyscf_chk_mol(chkfile, 'mcscf')``.
        fcidump_file: FCIDUMP whose header gives NORB, NELEC and MS2.
        ndets: if an integer (Python or numpy), number of leading determinants kept.
        dice_file: Dice wave-function file.
        blocks, nWalkers, rediag: written to the XML input.
        chol: Cholesky threshold forwarded to ``write_hamil_mol``.
        verbose: forwarded to ``write_hamil_mol``.
    '''
    filename = 'qmc_' + filename

    nmo, nelec, ms2 = _read_fcidump_header(fcidump_file)
    nalpha = (nelec + ms2) // 2
    nbeta = (nelec - ms2) // 2

    # NOTE(review): `extract_dice` is not imported by this module's visible
    # imports; it must be made available (e.g. from the QMCPACK/ipie tools)
    # or this line raises NameError.
    ci, occa, occb = extract_dice.read_dice_wavefunction(dice_file)
    ci, occa, occb = np.asarray(ci), np.asarray(occa), np.asarray(occb)

    # Sort by decreasing |ci|, optionally keeping only the leading ndets.
    ix = np.argsort(np.abs(ci))[::-1]
    if isinstance(ndets, (int, np.integer)):
        ix = ix[:ndets]
    ci, occa, occb = ci[ix], occa[ix], occb[ix]

    # Hamiltonian + SHCI trial into the same h5 file.
    scf_data = load_from_pyscf_chk_mol(chkfile, 'mcscf')
    write_hamil_mol(scf_data, filename + ".h5", chol, verbose=verbose, real_chol=False)

    ci = np.array(ci, dtype=np.complex128)
    uhf = True  # UHF always true for CI expansions.
    write_qmcpack_wfn(filename + ".h5", (ci, occa, occb), uhf, (nalpha, nbeta), nmo)

    options = {
        "Wavefunction": {
            "rediag": rediag,
        },
        "execute": {
            "timestep": 0.005,
            "blocks": blocks,
            "steps": 5,
            "substeps": 5,
            "nWalkers": nWalkers,
        },
    }
    xml_file = filename + ".xml"
    write_xml_input(xml_file, filename + ".h5", filename + ".h5",
                    options=options, rng_seed=7)

    # Replace every quoted "qmc" in the generated XML with the file prefix
    # (literal substitution, no shell involved).
    with open(xml_file, encoding="utf-8") as fxml:
        xml = fxml.read()
    with open(xml_file, "w", encoding="utf-8") as fxml:
        fxml.write(xml.replace('"qmc"', '"{0}"'.format(filename)))