#!/usr/bin/env python
# coding: utf-8

import os, sys, h5py
sys.path.append("/root/pyscf")
sys.path.append("/root")
sys.path.append("/root/AFQMC")

import numpy as np
from pyscf import gto, scf, mcscf, mrpt, fci, lib, tools, lo
from pyscf.mcscf import avas
from mrh.my_pyscf.mcscf.lasscf_o0 import LASSCF
from mrh.my_pyscf.mcscf import addons
from pyscf.tools import molden
from utils import tool as afqmc_tools
from utils import tool_magmom
from pyscf.lib import logger

# Geometry: parallel-displaced (PD) benzene dimer conformer.
# Ref: https://www.rsc.org/suppdata/d2/cp/d2cp04335a/d2cp04335a1.pdf
# Atom order matters downstream: atoms 0-11 are the first ring (6 C then 6 H)
# and atoms 12-23 are the second ring (6 C then 6 H).
mol = gto.Mole()
mol.charge = 0
mol.spin = 0  # closed shell (number of unpaired electrons = 0)
mol.atom = """
C 1.275754 -1.235450 0.698770
C 1.275754 -1.235450 -0.698770
C 0.276643 -1.918543 1.397540
C 0.276643 -1.918543 -1.397540
C -0.722468 -2.601635 0.698770
C -0.722468 -2.601635 -0.698770
H 2.050091 -0.706035 1.240335
H 2.050091 -0.706035 -1.240335
H 0.276643 -1.918543 2.480670
H 0.276643 -1.918543 -2.480670
H -1.496806 -3.131050 1.240335
H -1.496806 -3.131050 -1.240335
C -2.133751 2.490380 0.698770
C -2.133751 2.490380 -0.698770
C -1.134640 3.173473 1.397540
C -1.134640 3.173473 -1.397540
C -0.135529 3.856565 0.698770
C -0.135529 3.856565 -0.698770
H -2.908088 1.960965 1.240335
H -2.908088 1.960965 -1.240335
H -1.134640 3.173473 2.480670
H -1.134640 3.173473 -2.480670
H 0.638809 4.385980 1.240335
H 0.638809 4.385980 -1.240335
"""
mol.basis = 'STO-6G'
mol.verbose = 5
mol.build()

# Restricted Hartree-Fock reference for the dimer.
# Restart from the density matrix stored in the checkpoint file when it exists.
# On a fresh run there is no file yet, so pass dm=None and let PySCF build its
# default initial guess instead of calling from_chk on a missing file.
mf = scf.RHF(mol)
mf.chkfile = 'chkfile'
dm = mf.from_chk(mf.chkfile) if os.path.isfile(mf.chkfile) else None
mf.kernel(dm)
molden.from_mo(mol, 'dibenzene_hf.molden', mf.mo_coeff)

## Build localized pi-orbital guesses with PiOS, one MakePiOS call per ring
## Ref: 10.1021/acs.jctc.8b01196
from pyscf.mcscf.PiOS import MakePiOS
# The six carbons of each ring: atoms 0-5 and 12-17 in mol.atom, written here
# as 1-based numbers. NOTE(review): assumes MakePiOS expects 1-based atom
# indices; confirm.
PiAtoms_benzene1 = [1, 2, 3, 4, 5, 6]
PiAtoms_benzene2 = [13, 14, 15, 16, 17, 18]
pios_active = []
for pi_atoms in (PiAtoms_benzene1, PiAtoms_benzene2):
    N_Core, N_Actorb, N_Virt, nelec, coeff = MakePiOS(mol, mf, pi_atoms)
    # Keep only this ring's active columns [N_Core, N_Core + N_Actorb).
    pios_active.append(coeff[:, N_Core:N_Core + N_Actorb])
as1, as2 = pios_active
# Place both rings' orbitals into the CAS(12,12) active window of the RHF
# orbitals. There are 84 electrons, so ncore = (84 - 12) // 2 = 36 and the
# window is 36:48 (ring 1 first, then ring 2).
# NOTE(review): the surrounding RHF columns are not re-orthogonalized against
# these orbitals, and mo_PiOS is only written to molden here; it is not passed
# to the CASSCF/LASSCF kernels in this script.
mo_PiOS = mf.mo_coeff.copy()
mo_PiOS[:, 36:42] = as1
mo_PiOS[:, 42:48] = as2
molden.from_mo(mol, 'dibenzene_PiOS.molden', mo_PiOS)


## Run a single CAS(12,12) over the whole dimer (both rings' active spaces together)
# Restart from 'mcscf/mo_coeff' in the shared checkpoint file when it is present.
# lib.chkfile.load returns None for a missing key, and mycas.kernel(None) then
# starts from the RHF orbitals (mf.mo_coeff).
mycas = mcscf.CASSCF(mf, 12, 12)
mycas.chkfile = mf.chkfile  # same file as the SCF; the name is defined once on mf
mycas.natorb = True  # report the active space in natural orbitals
mo = lib.chkfile.load(mycas.chkfile, 'mcscf/mo_coeff')
e_tot, e_cas, ci, mo_coeff, mo_energy = mycas.kernel(mo)[:5]
molden.from_mo(mol, 'dibenzene_cas.molden', mycas.mo_coeff)

## Run LASSCF with two (6e, 6o) fragments, one per benzene ring
# spin_sub=(1,1): presumably a singlet in each fragment; confirm against mrh docs.
mylas = LASSCF(mf, [6,6], [6,6], spin_sub=(1,1))
mylas.chkfile = mf.chkfile
# mylas.verbose = 0
# Atom indices of each ring in mol.atom (0-11 and 12-23).
# NOTE(review): frag_atom_list is not used below. It looks intended for
# mylas.localize_init_guess(frag_atom_list, ...) when the checkpoint file has
# no stored LAS orbitals; confirm.
frag_atom_list = [list(range(12)), list(range(12,24))]
# Restart orbitals from a previous run; None if the key is absent.
mo = lib.chkfile.load(mylas.chkfile, 'las/mo_coeff')
e_tot, e_las, ci, mo_coeff, mo_energy, h2eff_sub, veff = mylas.kernel(mo)
mo1 = mo_coeff
# Save the converged LAS orbitals so the next run restarts from them.
# Test for the dataset itself rather than the 'las' group: if the group
# existed without 'mo_coeff', deleting 'las/mo_coeff' would raise KeyError.
with h5py.File(mylas.chkfile, 'r+') as f:
    if 'las/mo_coeff' in f:
        del f['las/mo_coeff']
    f['las/mo_coeff'] = mo_coeff
molden.from_mo(mol, 'dibenzene_las.molden', mylas.mo_coeff)

## Generate AFQMC inputs
import subprocess

# Forwarded unchanged to the afqmc_tools generators as `cas`; its meaning is
# defined in utils.tool. NOTE(review): confirm.
cas = [60, 60]
qmc_blocks = 25000
qmc_walkers = 20


def _add_checkpoint_params(tag):
    """Append checkpoint / HDF5-output parameters to ``qmc_<tag>.xml`` in place.

    Runs two GNU ``sed -i`` passes, one after the other. The first appends a
    ``checkpoint`` parameter line after every line containing ``"substeps"``.
    The second appends an ``hdf_write_file`` parameter naming ``qmc_<tag>``
    after every line containing ``"checkpoint"``.

    The sed scripts are passed as argv entries, so no shell quoting is involved.
    ``\\ `` is written escaped so Python does not warn about an invalid escape
    sequence. A non-zero sed exit status raises ``subprocess.CalledProcessError``
    instead of being silently ignored.
    """
    xml_path = 'qmc_{0}.xml'.format(tag)
    sed_scripts = (
        '/"substeps">*/a \\ \\ \\ \\ <parameter name="checkpoint">1</parameter>',
        '/"checkpoint">*/a \\ \\ \\ \\ <parameter name="hdf_write_file">qmc_{0}</parameter>'.format(tag),
    )
    for script in sed_scripts:
        subprocess.run(['sed', '-i', '-e', script, xml_path], check=True)


afqmc_tools.gen_rhf_trial('rhf', mf.chkfile, blocks=qmc_blocks, nWalkers=qmc_walkers, cas=cas)
_add_checkpoint_params('rhf')

# tol is forwarded to the trial generators (presumably a CI-coefficient
# cutoff; confirm). Tags come from str.format on the float, e.g. 'cas_1e-05'.
tols = [0.05, 0.01, 0.005, 0.001, 0.00001, 0.000001, 0.0000001, 0.00000001]
for tol in tols:
    # For each tolerance: CAS trial + input + checkpoint params, then the same for LAS.
    for kind, gen_trial, solver in (('cas', afqmc_tools.gen_cas_trial, mycas),
                                    ('las', afqmc_tools.gen_las_trial, mylas)):
        tag = '{0}_{1}'.format(kind, tol)
        gen_trial(tag, mol, solver, tol=tol, cas=cas)
        afqmc_tools.gen_input(tag, blocks=qmc_blocks, nWalkers=qmc_walkers)
        _add_checkpoint_params(tag)