import jax

import jax.numpy as jnp
import netket as nk
nk.config.netket_experimental_fft_autocorrelation = True
import matplotlib.pyplot as plt
import numpy as np

# --- Model parameters --------------------------------------------------------
Lx = 8    # lattice extent along x (passed to qsl.lattice.Torus)
Ly = 8    # lattice extent along y (passed to qsl.lattice.Torus)
J = 1.0   # passed to qsl.operators.ToricCode as J
hx = 0.3  # passed to qsl.operators.ToricCode as hx (also scales the Jastrow init noise)
hz = 0.   # passed to qsl.operators.ToricCode as hz

# --- Lattice and Hilbert space -----------------------------------------------
import qsl

lattice = qsl.lattice.Torus(1.0, Lx, Ly)
N = lattice.N  # number of spins, read from the lattice itself

# Output prefix shared by every file this run writes (logger and run summary).
# The "N=" label is derived from the lattice instead of being recomputed by
# hand (previously 3*Lx*Ly), so the directory name cannot disagree with the
# system size that is actually simulated.
folder = f'N={N}/hx={hx:.2f}_hz={hz:.2f}_'

hi = nk.hilbert.Spin(1/2, N)

# --- Variational ansatz --------------------------------------------------------
# `inf` is divided by N and handed to both initialisers as `infinity`.
inf = 25.0
ma = qsl.models.JMF_dense(
    jastrow_init=qsl.models.rvb_init_W(
        infinity=inf/N, lattice=lattice, stddev=5*hx/N, restricted_hilbert=False,
    ),
    mf_init=qsl.models.rvb_init_phi(
        infinity=inf/N, lattice=lattice, restricted_hilbert=False,
    ),
    # Complex parameters; the SR preconditioner set up later in this script
    # is configured with holomorphic=True.
    param_dtype=jnp.complex128,
)

# --- Sampling ------------------------------------------------------------------
# Mixture of three move types with probabilities 0.75 / 0.125 / 0.125.
rule = qsl.rules.MixedRuleRVB(lattice=lattice, probs=[0.75, 0.25/2, 0.25/2])
# sweep_size=N: N Metropolis proposals between two stored samples per chain.
sa = nk.sampler.MetropolisSampler(hi, rule, n_chains=1_024, sweep_size=N, dtype=np.int8)
# 20_480 samples over 1_024 chains -> 20 samples per chain per step.
# NOTE(review): only 2 discarded samples per chain; raise if chains mix slowly.
# Pass chunk_size=32 (or similar) here to trade speed for lower peak memory.
vs = nk.vqs.MCState(sa, ma, n_samples=20_480, n_discard_per_chain=2)


# Hamiltonian built from the J / hx / hz parameters of this run.
TC = qsl.operators.ToricCode(hi, lattice, J=J, hx=hx, hz=hz)

# JSON logger writing under the run's output prefix; parameters are saved too.
logvmc = qsl.logging.StateJson(
    output_prefix=folder + 'VMC',
    save_params=True,
)

# Plain gradient-descent step, preconditioned with stochastic reconfiguration.
# The quantum geometric tensor is built in Jacobian/PyTree form.
# NOTE(review): holomorphic=True assumes the ansatz is holomorphic in its
# complex parameters — confirm for JMF_dense.
optimizer = nk.optimizer.Sgd(1e-3)
preconditioner = nk.optimizer.SR(
    diag_shift=1e-4,
    qgt=nk.optimizer.qgt.QGTJacobianPyTree,
    holomorphic=True,
)

# Ground-state search driver tying the Hamiltonian, optimizer, state and
# preconditioner together.
gs = nk.driver.VMC(TC, optimizer, variational_state=vs, preconditioner=preconditioner)



# Write a human-readable summary of the run configuration next to the logs.
# The VMC repr's trailing ")" is stripped so that the optimizer and SR lines
# read as nested inside it; the ")" is re-appended after the SR line.
import os

# NOTE(review): `folder` already ends in "_", so this file is named
# "..._ _infos.md" with a double underscore; kept as-is so existing paths match.
_infos_path = folder + '_infos.md'

# `folder` includes a directory component ("N=.../"); create it if needed so
# open() does not fail when nothing has created that directory yet.
_infos_dir = os.path.dirname(_infos_path)
if _infos_dir:
    os.makedirs(_infos_dir, exist_ok=True)

_summary_lines = [
    'Graph : ' + repr(lattice),
    'State : ' + repr(vs),
    'Machine : ' + repr(ma),
    'Sampler : ' + repr(sa),
    'Hamiltonian : ' + repr(TC),
    'VMC : ' + repr(gs)[:-1],
    '  optimizer : ' + repr(optimizer),
    '  SR : ' + repr(preconditioner) + ')',
]

# Explicit UTF-8 so non-ASCII characters in reprs do not depend on the
# platform's default encoding.
with open(_infos_path, 'w', encoding='utf-8') as fh:
    fh.writelines(line + '\n' for line in _summary_lines)


# Early-stopping policy:
#  * InvalidLossStopping aborts once the loss stays invalid (NaN/Inf) beyond
#    the given patience;
#  * ConvergenceStopping stops once the monitored energy variance reaches a
#    target that is scaled with the system size as (N/24)**2.
_stopping_callbacks = [
    nk.callbacks.InvalidLossStopping(patience=2),
    nk.callbacks.ConvergenceStopping(
        monitor='variance',
        target=1e-4 * (N / 24) ** 2,
    ),
]

# Run at most 1000 optimisation steps, logging every step to `logvmc`.
gs.run(
    1_000,
    out=[logvmc],
    callback=_stopping_callbacks,
    show_progress=True,
)
