# The core implementations of RWM, HMC, and the leapfrog and implicit GL1 integrators.

from math import sqrt

import numpy as np
from scipy import linalg
import torch

from gprcommon import *
from anderson import *

class RWMProposal:
    """Random-walk Metropolis proposal.

    A step adds scaled noise from ``random.normal`` to ``theta``; the momentum
    and cached ``K_inv_y`` arguments are accepted only so the call shape lines
    up with the other proposal class in this file.
    """

    # The momentum argument is never read by this proposal.
    requires_pi = False

    def __init__(self, V):
        """Store the potential ``V``; it is invoked as ``V(theta, K_inv_y=...)``."""
        self.V = V

    def step(self, h, theta, _pi, _K_inv_y, _linear_preconditioner=None, final_step=False):
        """Propose ``theta + h * noise``.

        Returns ``(new_theta, _pi, _K_inv_y, 1)``: the momentum and cached
        solve are handed back untouched and the last entry is the constant 1.
        ``_linear_preconditioner`` and ``final_step`` are ignored.
        """
        noise = tt(random.normal(size=theta.shape))
        proposed_theta = theta + h * noise
        return proposed_theta, _pi, _K_inv_y, 1

    def logprob(self, theta, _pi, _K_inv_y):
        """Return ``-V(theta)``, always recomputing the linear solve."""
        # Don't use the stored value of K_inv_y -- it's inaccurate
        potential = self.V(theta, K_inv_y=None)
        return -potential

class HMCProposal:
    """Hamiltonian Monte Carlo proposal that delegates stepping to an integrator.

    The log-probability is ``-V(theta) - 1/2 * pi' M pi`` with ``M`` the matrix
    supplied at construction (identity by default).
    """

    # This proposal reads the momentum argument.
    requires_pi = True

    def __init__(self, V, integrator, M=None, **integrator_args):
        """Build the proposal.

        V: potential; called as ``V(theta, K_inv_y=...)``.
        integrator: factory called as ``integrator(V, M, **integrator_args)``.
        M: optional matrix for the quadratic momentum term; when omitted an
           identity of size ``V.kernel.base_K.data_required`` is used.
        """
        self.V = V
        if M is None:
            M = np.eye(V.kernel.base_K.data_required)
        # Matrix square root of inv(M), kept as an attribute (not used within this class).
        inverse_M = np.linalg.inv(M)
        self.sqrtinv_M = tt(linalg.sqrtm(inverse_M))
        self.M = tt(M)
        self.integrator = integrator(self.V, self.M, **integrator_args)

    def step(self, h, theta, pi, K_inv_y, linear_preconditioner=None, final_step=False):
        """Advance one integrator step and return whatever the integrator returns."""
        return self.integrator(
            h, theta, pi, K_inv_y, linear_preconditioner, final_step=final_step
        )

    def logprob(self, theta, pi, K_inv_y):
        """Return ``-V(theta) - 1/2 * pi' M pi``, one value per row of ``pi``."""
        potential = self.V(theta, K_inv_y=K_inv_y)
        # Batched quadratic form: row b yields pi[b] @ M @ pi[b].
        quadratic = torch.einsum("bi,ij,bj->b", pi, self.M, pi)
        return -potential - (1/2) * quadratic

class LeapfrogIntegrator:
    """Leapfrog integrator: half momentum kick, full position drift, half momentum kick."""

    def __init__(self, V, mass_matrix):
        """Keep the potential, its kernel and data vector, and the mass matrix."""
        self.V = V
        self.kernel = V.kernel
        self.y = V.problem.y
        self.mass_matrix = mass_matrix

    def __call__(self, h, theta0, pi0, K_inv_y0, linear_preconditioner, final_step=False):
        """Take one leapfrog step of size ``h`` from ``(theta0, pi0)``.

        ``K_inv_y0`` is the linear solve passed to the first gradient call; a
        fresh solve at the new position is computed with gradients disabled.
        ``final_step`` is accepted but not used here.

        Returns ``(theta, pi, K_inv_y, niter)`` where ``niter`` is read from
        the ``"niter"`` entry of the solver's info.
        """
        half_h = h / 2

        # First half kick of the momentum at the starting position.
        pi_half = pi0 - half_h * self.V.grad(theta0, K_inv_y=K_inv_y0)

        # Full drift of the position using mass_matrix @ pi for each batch row.
        drift = torch.einsum("ij,bj->bi", self.mass_matrix, pi_half)
        theta = theta0 + h * drift

        # Re-solve at the new position; the solve itself is not differentiated.
        with torch.no_grad():
            K_inv_y, solve_info = self.kernel.K_inv(
                theta, self.y, preconditioner=linear_preconditioner, return_info=True
            )

        # Second half kick at the new position.
        pi_new = pi_half - half_h * self.V.grad(theta, K_inv_y=K_inv_y)
        return theta, pi_new, K_inv_y, solve_info["niter"]

class GL1Integrator:
    """Implicit GL1 integrator solved as a joint fixed-point problem.

    Each step iterates simultaneously on the end-of-step position, the
    end-of-step momentum and a running linear-solve estimate ``K_inv_y``,
    handing the combined map to ``anderson_fpiter``.
    """

    def __init__(self, V, mass_matrix, m_max=10, max_iters=1000, damping_coeff=0):
        """Store the problem pieces and fixed-point solver settings.

        m_max, max_iters: forwarded to ``anderson_fpiter`` on every step.
        damping_coeff: weight on the previous iterate when blending; the new
            iterate gets ``1 - damping_coeff``.
        The solver tolerance is taken from ``V.kernel.rtol``.
        """
        self.V = V
        self.kernel = V.kernel
        self.y = V.problem.y
        self.rtol = V.kernel.rtol
        self.mass_matrix = mass_matrix
        self.m_max = m_max
        self.max_iters = max_iters
        self.damping_coeff = damping_coeff
        self.new_coeff = 1 - damping_coeff

    def _dq_dp(self, K_inv_y, q, p):
        """Calculates dq/dt and dp/dt given the Hamiltonian H(q, p) = V(q) + 1/2 p' M p."""
        velocity = torch.einsum("ij,bj->bi", self.mass_matrix, p)
        potential_grad = self.V.grad(q, K_inv_y=K_inv_y)
        return velocity, -potential_grad

    def __call__(self, h, theta0, pi0, K_inv_y0, linear_preconditioner, final_step=False):
        """Take one implicit step of size ``h`` from ``(theta0, pi0)``.

        ``K_inv_y0`` seeds the linear-solve part of the iterate. When
        ``linear_preconditioner`` is None an identity map is used in its place.
        With ``final_step`` set, ``K_inv_y`` is recomputed by ``kernel.K_inv``
        at the final position before returning.

        Returns ``(theta, pi, K_inv_y, iters)`` with ``iters`` as reported by
        ``anderson_fpiter``.
        """
        if linear_preconditioner is None:
            def linear_preconditioner(x):
                return x

        # Column widths of the three pieces packed side by side in the iterate.
        split_shapes = [K_inv_y0.shape[1], theta0.shape[1], pi0.shape[1]]

        @anderson_pack(split_shapes)
        def fixed_point_problem(K_inv_y, q, p):
            # Derivatives evaluated at the average of the start state and the candidate end state.
            dq, dp = self._dq_dp(K_inv_y.detach(), (theta0 + q) / 2, (pi0 + p) / 2)
            stepped_theta = theta0 + h * dq
            stepped_pi = pi0 + h * dp

            # Blend the previous iterate with the freshly stepped values.
            theta = self.damping_coeff * q + self.new_coeff * stepped_theta
            pi = self.damping_coeff * p + self.new_coeff * stepped_pi

            # Preconditioned residual correction of the linear solve at the averaged position.
            residual = self.y.unsqueeze(0) - self.kernel.K((theta + theta0) / 2, K_inv_y)
            new_K_inv_y = K_inv_y + linear_preconditioner(residual)
            return new_K_inv_y, theta, pi

        initial_x = torch.concatenate([K_inv_y0, theta0, pi0], dim=1).type(tensor)
        iteration_result, iters = anderson_fpiter(
            fixed_point_problem,
            initial_x,
            tol=self.rtol,
            m_max=self.m_max,
            max_iters=self.max_iters,
        )
        K_inv_y, theta, pi = torch.split(iteration_result, split_shapes, dim=1)

        if final_step:
            # Replace the iterated estimate with a direct solve at the final position.
            with torch.no_grad():
                K_inv_y, _ = self.kernel.K_inv(
                    theta, self.y, preconditioner=linear_preconditioner, return_info=True
                )

        return theta, pi, K_inv_y, iters