#!/usr/bin/env python
# coding: utf-8

import os, sys
import numpy as np
from pyscf import gto, scf, mcscf, mrpt, fci, lib, tools
from pyscf.tools import molden
from h10_struct import structure as struct
from mrh.my_pyscf.mcscf.lasscf_o0 import LASSCF
from mrh.my_pyscf.mcscf import addons
from pyscf.tools import molden
from utils import tool as afqmc_tools
from utils import tool_magmom

# Potential-energy scan of an H10 chain.  For each spacing R the script runs
# RHF and UHF, a LASSCF with five (2,2) fragments, CASSCF(8,8) + NEVPT2 and
# CASSCF(10,10) (labelled "FCI" below).  It then writes QMCPACK AFQMC trial
# wavefunctions and inputs built from the LAS, CAS, RHF and UHF solutions.
#
# R runs from 3.0 down to 0.5 in steps of 0.1.  The units are whatever
# h10_struct.structure interprets R as (TODO confirm).  The LAS and CAS
# orbitals of each geometry seed the next one, so the scan order matters.
Rs = np.arange(3.0,0.4,-0.1)
mo0 = None   # LASSCF orbitals carried over from the previous geometry
mo1 = None   # CASSCF(8,8) orbitals carried over from the previous geometry

# ---- Loop-invariant settings (hoisted out of the scan) -------------------
# Five two-atom fragments: atoms (0,1), (2,3), ..., (8,9).  A comprehension is
# used so the fragment index does not leak into / clobber any outer name.
frag_atom_list = [np.arange(2) + 2*ifrag for ifrag in range(5)]
# Alternating +1/-1 1s moments on the ten H atoms, used as the UHF guess.
# Expands to "H:1s:1 H:1s:-1 H:1s:1 ... H:1s:-1" (10 entries).
magmom = " ".join(["H:1s:1 H:1s:-1"] * 5)
# AFQMC run parameters shared by every generated QMCPACK input.
blocks = 8000
nWalkers = 20

for R in Rs:
    # Geometry tag used in every output file name, e.g. "3.00".
    tag = "{0:1.2f}".format(R)

    ## RHF calculation
    mol = struct(natm=10, R=R, basis='sto6g', symmetry=False, verbose=0)
    mf = scf.RHF(mol)
    mf.max_cycle = 200
    mf.chkfile = "h10_rhf_{}.chk".format(tag)
    mf.kernel()

    ## UHF calculation (built on its own Mole instance, as before, so that
    ## anything guess_magmom does to mf2.mol cannot affect the RHF-based steps)
    mol = struct(natm=10, R=R, basis='sto6g', symmetry=False, verbose=0)
    mf2 = scf.UHF(mol)
    mf2.max_cycle = 200
    mf2.chkfile = "h10_uhf_{}.chk".format(tag)
    dm0 = tool_magmom.guess_magmom(mf2, magmom=magmom, normalized=True)
    mf2.kernel(dm0)

    ## Run a LAS(2,2) for each fragment
    mylas = LASSCF(mf, [2]*5, [2]*5, spin_sub=(1,1,1,1,1))
    mylas.max_cycle_macro = 500
    mylas.chkfile = mf.chkfile    # LASSCF object doesn't really have a chkfile
    if mo0 is None:
        # First geometry only.  mf.chkfile is the file that was previously
        # hard-coded as 'h10_rhf_3.00.chk' (Rs[0] == 3.0), now derived from R.
        # NOTE(review): the RHF step presumably writes only 'scf/*' keys, so on
        # a fresh run this load returns None and localize_init_guess falls
        # back to its default orbitals.  If an earlier run left a
        # 'mcscf/mo_coeff' entry in this file (a CASSCF built from `mf` likely
        # defaults to mf.chkfile), those stale orbitals are used instead.
        # Confirm that is intended.
        mo = lib.chkfile.load(mf.chkfile, 'mcscf/mo_coeff')
        mo0 = mylas.localize_init_guess(frag_atom_list, mo)
    else:
        # Re-orthonormalize the previous geometry's LAS orbitals for this one.
        mo0 = mcscf.project_init_guess(mylas, mo0)
    e_las_tot, e_las, las_ci, mo_coeff, mo_energy, h2eff_sub, veff = mylas.kernel(mo0)
    mo0 = mo_coeff
    molden.from_mo(mol, "H10_LAS_{}.molden".format(tag), mylas.mo_coeff)

    ## Run a CAS(8,8), warm-started from the previous geometry's orbitals.
    # A single solve: the earlier version also ran an unseeded mycas.kernel()
    # whose results were all overwritten by the seeded solve below.
    mycas = mcscf.CASSCF(mf, 8, 8)
    if mo1 is None:
        mo1 = mf.mo_coeff
    else:
        mo1 = mcscf.project_init_guess(mycas, mo1)
    e_cas_tot, e_cas, cas_ci, mo_coeff, mo_energy = mycas.kernel(mo1)[:5]
    mo1 = mo_coeff

    ## NEVPT2 on top of the converged CAS(8,8) reference
    e_corr = mrpt.NEVPT(mycas).kernel()
    e_pt2 = mycas.e_tot + e_corr
    # NOTE(review): e_pt2 (and the FCI energies below) are computed but never
    # printed or saved by this script.

    ## Run FCI: CASSCF(10,10) over the ten-atom STO-6G chain.  The object is
    ## named `myfci` so it does not shadow the imported `pyscf.fci` module.
    myfci = mcscf.CASSCF(mf, 10, 10)
    e_fci, e_fci_cas, fci_ci, fci_mo_coeff, fci_mo_energy = myfci.kernel()[:5]

    ## Write QMCPack inputs
    afqmc_tools.gen_las_trial('las_{}'.format(tag), mol, mylas, tol=1e-8)
    afqmc_tools.gen_input('las_{}'.format(tag), blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_cas_trial('cas_{}'.format(tag), mol, mycas, tol=1e-8)
    afqmc_tools.gen_input('cas_{}'.format(tag), blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_rhf_trial('rhf_{}'.format(tag), mf.chkfile, blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_uhf_trial('uhf_{}'.format(tag), mf2.chkfile, blocks=blocks, nWalkers=nWalkers)

    
