#!/usr/bin/env python
# coding: utf-8

import os, sys
import numpy as np
from pyscf import gto, scf, mcscf, mrpt, fci, lib, tools
from c2h4n4_struct import structure as struct
from mrh.my_pyscf.mcscf.lasscf_o0 import LASSCF
from mrh.my_pyscf.mcscf import addons
from pyscf.tools import molden
from utils import tool as afqmc_tools
from utils import tool_magmom
from pyscf.lib import logger

# Reference N-N distance (presumably the equilibrium value); every scan
# point below is a displacement from it.
rnn0 = 1.2499144882526947
# Displacements +1.5, +1.4, ..., -0.3 (the -0.301 stop keeps -0.3 inside the
# half-open arange), plus one extra point at +0.05 spliced in at index 15,
# i.e. between the +0.1 and 0.0 points.
Rs = rnn0 + np.arange(1.5, -0.301, -0.1)
Rs = np.concatenate((Rs[:15], [Rs[14] - 0.05], Rs[15:]))
# Warm-start state carried from one geometry to the next:
#   dm0 / dm1        -> RHF / UHF density matrices
#   mo0, mo1, mo2, mo3 -> LAS(8,8), LAS(10,10), CAS(8,8), CAS(10,10) orbitals
dm0 = dm1 = None
mo0 = mo1 = mo2 = mo3 = None

def stable_opt_internal(mf, max_attempts=20):
    '''Re-run an SCF object until its internal stability check passes.

    Calls ``mf.stability(return_status=True)``. While the solution is reported
    unstable, a density matrix is rebuilt from the orbitals returned by the
    stability analysis and the SCF is re-run from that guess. At most
    ``max_attempts`` re-optimizations are performed.

    Args:
        mf : PySCF mean-field object
            A converged SCF object (a second-order UHF in this script).
        max_attempts : int
            Maximum number of re-optimization cycles. The default of 20 is the
            value that used to be hard-coded.

    Returns:
        The mean-field object returned by the last ``mf.run`` call, or the
        input object if it was already stable.
    '''
    log = logger.new_logger(mf)
    # PySCF's stability(return_status=True) yields
    # (mo_internal, mo_external, stable_internal, stable_external);
    # only the internal orbitals and flag are used here.
    mo_new, _, stable, _ = mf.stability(return_status=True)
    cyc = 0
    while not stable and cyc < max_attempts:
        log.note('Try to optimize orbitals until stable, attempt %d' % cyc)
        # Restart the SCF from the density of the stability-rotated orbitals.
        dm_guess = mf.make_rdm1(mo_new, mf.mo_occ)
        mf = mf.run(dm_guess)
        mo_new, _, stable, _ = mf.stability(return_status=True)
        cyc += 1
    if not stable:
        # Giving up on stability is a failure condition, so report it as a warning.
        log.warn('Stability Opt failed after %d attempts' % cyc)
    return mf

# QMCPACK AFQMC run parameters; identical for every geometry.
blocks = 12000
nWalkers = 20

for R in Rs:
    dnn = R - rnn0
    tag = '{0:1.2f}'.format(R)  # geometry label shared by every output file

    ## RHF calculation (warm-started from the previous geometry's density)
    mol = struct(dnn, dnn, basis='6-31G', symmetry=False, verbose=0)
    mf = scf.RHF(mol)
    mf.max_cycle = 200
    mf.chkfile = "rhf_{0}.chk".format(tag)
    # Fresh guess only when no previous RHF density (dm0) exists.
    if dm0 is None:
        dm0 = mf.get_init_guess()
    mf.kernel(dm0)
    dm0 = mf.make_rdm1()

    ## UHF calculation (second-order solver) + internal-stability re-optimization
    mf2 = scf.UHF(mol).newton()
    mf2.max_cycle = 200
    mf2.chkfile = "uhf_{0}.chk".format(tag)
    if dm1 is None:
        dm_alpha, dm_beta = mf2.get_init_guess()
        # Overwrite the leading 10x10 block of the beta guess so that the alpha
        # and beta densities differ -- presumably to seed a spin-broken solution.
        dm_beta[:10, :10] = 0.4
        dm1 = (dm_alpha, dm_beta)
    mf2 = mf2.run(dm1)
    mf2 = stable_opt_internal(mf2)
    dm1 = mf2.make_rdm1()

    ## LAS(8,8): two fragments of 4 orbitals / 4 electrons each
    mylas8 = LASSCF(mf, [4, 4], [4, 4], spin_sub=(1, 1))
    mylas8.chkfile = "rhf_{0}.chk".format(tag)  # LASSCF object doesn't really have a chkfile
    frag_atom_list = [[0, 1, 2], [7, 8, 9]]
    if mo0 is None:
        # NOTE(review): the first geometry (R ~ 2.75) reads 'mcscf/mo_coeff'
        # from rhf_1.25.chk, which this loop has not produced yet. The file
        # must already exist, presumably from an earlier run -- confirm.
        mo0 = lib.chkfile.load('rhf_1.25.chk', 'mcscf/mo_coeff')
        mo0 = mylas8.localize_init_guess(frag_atom_list, mo0)
    else:
        mo0 = mcscf.project_init_guess(mylas8, mo0)
    e_tot, e_las, ci, mo_coeff, mo_energy, h2eff_sub, veff = mylas8.kernel(mo0)
    mo0 = mo_coeff

    ## LAS(10,10): three fragments of 4, 2 and 4 orbitals/electrons
    mylas10 = LASSCF(mf, [4, 2, 4], [4, 2, 4], spin_sub=(1, 1, 1))
    mylas10.chkfile = "rhf_{0}.chk".format(tag)  # LASSCF object doesn't really have a chkfile
    frag_atom_list = [[0, 1, 2], [3, 4, 5, 6], [7, 8, 9]]
    if mo1 is None:
        # First geometry: localize starting from the LAS(8,8) orbitals above.
        mo1 = mylas10.localize_init_guess(frag_atom_list, mo_coeff)
    else:
        mo1 = mcscf.project_init_guess(mylas10, mo1)
    e_tot, e_las, ci, mo_coeff, mo_energy, h2eff_sub, veff = mylas10.kernel(mo1)
    mo1 = mo_coeff

    ## CASSCF(8,8): a single solve per geometry
    mycas8 = mcscf.CASSCF(mf, 8, 8)
    if mo2 is None:
        cas8_guess = None  # first geometry: CASSCF's default starting orbitals
    else:
        cas8_guess = mcscf.project_init_guess(mycas8, mo2)
    e_tot, e_cas, ci, mo_coeff, mo_energy = mycas8.kernel(cas8_guess)[:5]
    mo2 = mo_coeff

    ## CASSCF(10,10)
    mycas10 = mcscf.CASSCF(mf, 10, 10)
    if mo3 is None:
        mo3 = mo_coeff  # first geometry: start from the CAS(8,8) orbitals
    else:
        mo3 = mcscf.project_init_guess(mycas10, mo3)
    e_tot, e_cas, ci, mo_coeff, mo_energy = mycas10.kernel(mo3)[:5]
    mo3 = mo_coeff

    print(R, mf.e_tot, mf2.e_tot, mylas8.e_tot, mylas10.e_tot, mycas8.e_tot, mycas10.e_tot)

    ## Write QMCPACK inputs (same order as before: trial then input per method)
    trial_writers = (
        ('las8', afqmc_tools.gen_las_trial, mylas8),
        ('las10', afqmc_tools.gen_las_trial, mylas10),
        ('cas8', afqmc_tools.gen_cas_trial, mycas8),
        ('cas10', afqmc_tools.gen_cas_trial, mycas10),
    )
    for label, gen_trial, solver in trial_writers:
        prefix = '{0}_{1}'.format(label, tag)
        gen_trial(prefix, mol, solver, tol=1e-8)
        afqmc_tools.gen_input(prefix, blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_rhf_trial('rhf_{0}'.format(tag), mf.chkfile, blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_uhf_trial('uhf_{0}'.format(tag), mf2.chkfile, blocks=blocks, nWalkers=nWalkers)
