from gprcommon import *
from kernelframework import *
from chebyshevkernel import *
from proposal import *
from preconditioning import *
from processing import *
from convergence import *

import warnings

def log_mh(logprob_old, logprob_new):
    """Metropolis-Hastings acceptance probability from two log-probabilities.

    Computes exp(min(logprob_new - logprob_old, 0)) elementwise, i.e.
    min(1, p_new / p_old).  Despite the function's name, the value returned
    is the probability itself, not its logarithm.

    Parameters:
    logprob_old: Tensor of log-probabilities of the current states.
    logprob_new: Tensor of log-probabilities of the proposed states.

    Returns a tensor with the broadcast shape of the two inputs."""
    log_ratio = logprob_new - logprob_old
    # Cap the log-ratio at 0 before exponentiating so the result never exceeds 1.
    capped_log_ratio = torch.clamp(log_ratio, max=0)
    return torch.exp(capped_log_ratio)

def markov_sample(problem, kernel, use_pseudofermion, proposal, h, initial_theta, batches, n_steps, n_jump, step_data_processor, make_linear_preconditioner=None, new_preconditioner_every=-1, jitter_jump=False, show_progress=True, recovery_thetas=None, precondition_contour=False):
    """Run batched Metropolis-Hastings sampling of the kernel hyperparameters.

    Each of the n_steps iterations optionally redraws the pseudofermion field
    and the auxiliary variable pi, applies a number of proposal steps, and then
    accepts or rejects the result independently for every chain.

    Parameters:
    problem: An instance of a Problem class, containing the underlying data (problem.y is read here).  See gprcommon.py.
    kernel: An instance of a Kernel class, specifying the GPR kernel and its hyperparameters but not how they are sampled.  See kernelframework.py.
    use_pseudofermion: If true, proposal.V.sample_varphi is called at the start of every iteration (pseudofermion method); otherwise proposal.V.set_varphi(None) is called once up front (determinant method).
    proposal: The object that proposes new Markov chain steps (proposal.step) and computes Metropolis-Hastings log-probabilities (proposal.logprob).  See proposal.py.
    h: The step size handed to proposal.step.
    initial_theta: A one-dimensional tensor giving the initial values of the kernel hyperparameters.
    batches: The number of chains to run simultaneously, all starting from the same initial_theta.
    n_steps: The number of samples to take from each chain.
    n_jump: The number of proposal steps per Metropolis-Hastings accept/reject check (and hence per sample returned).
    step_data_processor: A callable with an `outputs` attribute; it receives (theta, pi, K_inv_y, varphi) after every check and returns the observables to store.  See processing.py.
    make_linear_preconditioner: A function called as make_linear_preconditioner(kernel, theta) that returns a preconditioner.  If None, no preconditioner is ever built.  See preconditioning.py.
    new_preconditioner_every: If positive, a new preconditioner is built on every iteration whose index is a multiple of this value.  If -1, one is built before every individual proposal step.  Any other value never builds one.
    jitter_jump: Instead of taking n_jump steps per check, take a random number of steps between 1 and n_jump (inclusive).
    show_progress: Wrap the iteration range in tqdm to display a progress bar.
    recovery_thetas: The complete array of thetas returned by a previous run; sampling continues from its last timestep (recovery_thetas[:, -1, :]).  NOTE: If this is not None, initial_theta is ignored.
    precondition_contour: If true, the current linear preconditioner is passed to proposal.V.sample_varphi; otherwise None is passed.  Only relevant when use_pseudofermion is set.

    Returns:
    A tuple holding one array per output of step_data_processor, then the array of acceptance probabilities, each transposed with axes (1, 0, 2) to batches x timesteps x values, followed by the list of summed iteration counts (one entry per sample)."""

    def build_preconditioner(at_theta):
        # No factory supplied means sampling proceeds without a preconditioner.
        if make_linear_preconditioner is None:
            return None
        return make_linear_preconditioner(kernel, at_theta)

    if recovery_thetas is not None:
        # Resume every chain from the last recorded timestep of the earlier run.
        theta = tt(recovery_thetas[:, -1, :])
    else:
        assert initial_theta.ndim == 1
        theta = initial_theta.unsqueeze(0).repeat(batches, 1)
    # NOTE(review): pi starts as a length-`batches` vector of zeros and is only
    # replaced by a fresh draw when proposal.requires_pi is set.
    pi = tt([0] * batches)
    K_inv_y = kernel.K_inv(theta, problem.y)

    recorded = [[] for _ in range(step_data_processor.outputs)]
    itercounts = []
    acceptances = []

    varphi = None
    linear_preconditioner = None

    if not use_pseudofermion:
        proposal.V.set_varphi(None)

    step_indices = range(n_steps)
    if show_progress:
        step_indices = tqdm(step_indices)

    for i in step_indices:

        if use_pseudofermion:
            contour_preconditioner = linear_preconditioner if precondition_contour else None
            varphi = proposal.V.sample_varphi(theta, preconditioner=contour_preconditioner)

        if proposal.requires_pi:
            # Gaussian noise of theta's shape, multiplied by proposal.sqrtinv_M.
            sqrtinv_mass = proposal.sqrtinv_M
            noise = tt(random.normal(size=theta.shape))
            pi = torch.einsum("ij,bj->bi", sqrtinv_mass, noise)

        # Recomputed every iteration because pi and varphi may have just changed.
        logprob = proposal.logprob(theta, pi, K_inv_y).unsqueeze(1)

        # Work on detached copies so the current state survives a rejection.
        new_theta = theta.clone().detach()
        new_pi = pi.clone().detach()
        new_K_inv_y = K_inv_y.clone().detach()

        itercount_total = 0
        if jitter_jump:
            jump_size = random.integers(1, n_jump, endpoint=True)
        else:
            jump_size = n_jump

        if new_preconditioner_every > 0 and i % new_preconditioner_every == 0:
            linear_preconditioner = build_preconditioner(new_theta)

        refresh_each_step = new_preconditioner_every == -1
        for j in range(jump_size):
            if refresh_each_step:
                linear_preconditioner = build_preconditioner(new_theta)
            is_last = j == jump_size - 1
            new_theta, new_pi, new_K_inv_y, itercount = proposal.step(h, new_theta, new_pi, new_K_inv_y, linear_preconditioner, final_step=is_last)
            new_theta = new_theta.detach()
            new_pi = new_pi.detach()
            new_K_inv_y = new_K_inv_y.detach()
            itercount_total += itercount

        new_logprob = proposal.logprob(new_theta, new_pi, new_K_inv_y).unsqueeze(1)
        acceptance_prob = log_mh(logprob, new_logprob)
        # One uniform draw per entry of acceptance_prob decides accept/reject.
        uniform_draw = tt(random.uniform(size=acceptance_prob.shape))
        acceptance = uniform_draw < acceptance_prob

        # Keep the proposal where accepted, the previous state elsewhere.
        theta = torch.where(acceptance, new_theta, theta)
        pi = torch.where(acceptance, new_pi, pi)
        K_inv_y = torch.where(acceptance, new_K_inv_y, K_inv_y)
        logprob = torch.where(acceptance, new_logprob, logprob)

        for bucket, observable in zip(recorded, step_data_processor(theta, pi, K_inv_y, varphi)):
            bucket.append(cpu(observable).numpy())

        itercounts.append(itercount_total)
        acceptances.append(cpu(acceptance_prob).numpy())

    # Transpose to get batches x timesteps x values
    stacked = tuple(np.array(series).transpose(1, 0, 2) for series in recorded + [acceptances])
    return (*stacked, itercounts)