#!/usr/bin/env python
# coding: utf-8

import os, sys
import numpy as np
from pyscf import gto, scf, mcscf, mrpt, fci, lib, tools
from c2h6n4_struct import structure as struct
from mrh.my_pyscf.mcscf.lasscf_o0 import LASSCF
from mrh.my_pyscf.mcscf import addons
from pyscf.tools import molden
from utils import tool as afqmc_tools
from utils import tool_magmom
from pyscf.lib import logger

# Reference distance for the scan (presumably the equilibrium N-N bond length;
# units follow whatever `struct` expects -- confirm).  Each scan point R is
# passed to `struct` as the displacement R - rnn0.
rnn0 = 1.23681571
# Coarse grid: rnn0 + 1.5 down to rnn0 - 0.3 in steps of 0.1 (19 points).
# One extra point, 0.05 below entry 14, is spliced in at index 15 to refine
# the grid there (20 points total).
Rs = np.insert(rnn0 + np.arange(1.5, -0.301, -0.1), 15, 0.0)
Rs[15] = Rs[14] - 0.05
# Warm-start state carried from one scan point to the next by the main loop:
# RHF density (dm0), UHF density (dm1), LAS orbitals (mo0), CAS orbitals (mo1).
# None means "no previous point yet".
dm0 = dm1 = mo0 = mo1 = None

def stable_opt_internal(mf, max_cycle=20):
    '''Re-run an SCF until it passes the internal stability check.

    Calls ``mf.stability(return_status=True)``. While the internal check
    reports an instability, the SCF is restarted from the density built out
    of the orbitals returned by that analysis. The stability check is then
    repeated. The loop stops after ``max_cycle`` restarts.

    Args:
        mf : SCF object whose ``stability`` method accepts
            ``return_status=True`` (in this script, a ``scf.UHF(mol).newton()``
            object that has already been converged).
        max_cycle : int
            Maximum number of restart attempts. The default of 20 matches
            the previously hard-coded limit.

    Returns:
        The SCF object from the last ``mf.run`` call, or the input ``mf``
        if it was already stable. If the attempt limit is reached, the
        object may still be unstable; a warning is logged in that case.
    '''
    log = logger.new_logger(mf)
    # With return_status=True the result is
    # (mo_internal, mo_external, stable_internal, stable_external);
    # only the internal stability is followed here.
    mo_new, _, stable, _ = mf.stability(return_status=True)
    attempt = 0
    while not stable and attempt < max_cycle:
        log.note('Try to optimize orbitals until stable, attempt %d', attempt)
        # Restart the SCF from the density of the orbitals proposed by the
        # stability analysis.
        dm_new = mf.make_rdm1(mo_new, mf.mo_occ)
        mf = mf.run(dm_new)
        mo_new, _, stable, _ = mf.stability(return_status=True)
        attempt += 1
    if not stable:
        # Best effort: return the last SCF rather than aborting the scan.
        log.warn('Stability Opt failed after %d attempts', attempt)
    return mf

# Settings shared by every scan point (constant, so built once).
blocks = 12000   # forwarded to afqmc_tools as the QMCPack `blocks` setting
nWalkers = 20    # forwarded to afqmc_tools as the QMCPack `nWalkers` setting
# Atom indices defining the two LAS fragments.
frag_atom_list = (list(range(3)), list(range(9, 12)))

# For each scan point: RHF, UHF, LASSCF and CASSCF calculations, each
# warm-started from the previous point. Orbitals are dumped to molden files
# and QMCPack AFQMC trial/input files are written.
for i, R in enumerate(Rs):
    # Label used in every checkpoint/output file name for this scan point.
    tag = "{0:1.2f}".format(R)
    dnn = R - rnn0

    ## RHF calculation (warm-started from the previous point's RHF density)
    mol = struct(dnn, dnn, basis='6-31G', symmetry=False, verbose=0)
    mf = scf.RHF(mol)
    mf.max_cycle = 200
    mf.chkfile = "rhf_" + tag + ".chk"
    if dm0 is None:
        dm0 = mf.get_init_guess()
    mf.kernel(dm0)
    dm0 = mf.make_rdm1()

    ## UHF calculation (warm-started from the previous point's UHF density)
    mf2 = scf.UHF(mol).newton()
    mf2.max_cycle = 200
    mf2.chkfile = "uhf_" + tag + ".chk"
    if dm1 is None:
        # First point only: build a spin-polarized guess from a UHF run at a
        # fixed reference geometry. A separate name keeps `dnn` of the current
        # scan point intact.
        dnn_guess = 1.7 - rnn0
        mol2 = struct(dnn_guess, dnn_guess, basis='6-31G', symmetry=False, verbose=0)
        mf3 = scf.UHF(mol2).newton()
        dm_alpha, dm_beta = mf3.get_init_guess()
        # Overwrite the leading 10x10 AO block of each spin density with
        # opposite-sign constants to break alpha/beta symmetry.
        # NOTE(review): -0.25 makes that alpha block negative; presumably it is
        # only a crude starting point -- confirm this is intended.
        dm_beta[:10, :10] = 0.25
        dm_alpha[:10, :10] = -0.25
        mf3.run((dm_alpha, dm_beta))
        dm1 = mf3.make_rdm1()
    mf2 = mf2.run(dm1)
    mf2 = stable_opt_internal(mf2)
    dm1 = mf2.make_rdm1()

    ## LASSCF with two (4e, 4o) fragments, i.e. LAS(8,8) overall
    # spin_sub=(1, 1): presumably singlet fragments -- confirm against mrh docs.
    mylas8 = LASSCF(mf, [4, 4], [4, 4], spin_sub=(1, 1))
    mylas8.max_cycle_macro = 200
    # LASSCF object doesn't really have a chkfile; the RHF checkpoint name is
    # attached to it as in the original script.
    # NOTE(review): presumably read downstream (e.g. by gen_las_trial) -- confirm.
    mylas8.chkfile = "rhf_" + tag + ".chk"
    if mo0 is None:
        mo0 = mylas8.localize_init_guess(frag_atom_list)
    else:
        # NOTE(review): prev_mol is not passed, so mo0 is not explicitly
        # projected from the previous geometry -- confirm this is intended.
        mo0 = mcscf.project_init_guess(mylas8, mo0)
    e_tot, e_las, ci, mo_coeff, mo_energy, h2eff_sub, veff = mylas8.kernel(mo0)
    mo0 = mo_coeff
    molden.from_mo(mol, "c2h6n4_las_" + tag + ".molden", mylas8.mo_coeff)

    ## CASSCF(8,8) over the whole molecule
    mycas8 = mcscf.CASSCF(mf, 8, 8)
    if mo1 is None:
        # First point: start from the RHF orbitals (CASSCF's default guess).
        e_tot, e_cas, ci, mo_coeff, mo_energy = mycas8.kernel()[:5]
    else:
        # Later points: start only from the previous point's CAS orbitals.
        # No cold-start CASSCF from the RHF orbitals is run; its result would
        # be discarded anyway.
        mo1 = mcscf.project_init_guess(mycas8, mo1)
        e_tot, e_cas, ci, mo_coeff, mo_energy = mycas8.kernel(mo1)[:5]
    mo1 = mo_coeff
    molden.from_mo(mol, "c2h6n4_cas_" + tag + ".molden", mycas8.mo_coeff)

    ## Write QMCPack inputs
    afqmc_tools.gen_las_trial('las8_' + tag, mol, mylas8, tol=1e-8)
    afqmc_tools.gen_input('las8_' + tag, blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_cas_trial('cas8_' + tag, mol, mycas8, tol=1e-8)
    afqmc_tools.gen_input('cas8_' + tag, blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_rhf_trial('rhf_' + tag, mf.chkfile, blocks=blocks, nWalkers=nWalkers)
    afqmc_tools.gen_uhf_trial('uhf_' + tag, mf2.chkfile, blocks=blocks, nWalkers=nWalkers)