# Two options for the data a sampler can output at every step: just the raw theta values (SaveThetas), or the raw thetas plus the result of GPR inference at a set of pre-specified points using those thetas (KernelInference).
# TimeProcessing wraps either option and prepends an elapsed-time stamp to its output.
import time

from gprcommon import tt

class SaveThetas:
    """Sampler output hook that records only the raw hyperparameters.

    Calling an instance returns a 1-tuple holding ``theta``. The other
    arguments are accepted only so the call signature matches the other
    output hooks in this module, and they are ignored.
    """

    # Number of elements in the tuple returned by __call__.
    outputs = 1

    def __call__(self, theta, pi, K_inv_y, varphi):
        recorded = (theta,)
        return recorded

class KernelInference:
    """Sampler output hook that records theta plus a kernel evaluation.

    At every sampler step it calls ``kernel.K`` with the current ``theta``
    and ``K_inv_y`` at a fixed set of inference points (with
    ``noise=False``). It returns ``(theta, inference_value)``.
    """

    # Number of elements in the tuple returned by __call__.
    outputs = 2

    def __init__(self, kernel, inference_point=None):
        """Store the inference points and the kernel's ``K`` callable.

        Args:
            kernel: object exposing a ``K`` method. When ``inference_point``
                is omitted, it must also expose ``base_K.problem.d``
                (presumably the input dimensionality -- TODO confirm).
            inference_point: points passed as ``xs`` to ``kernel.K``. It
                defaults to a single all-zeros point of length
                ``kernel.base_K.problem.d``, built with ``tt``.
        """
        # Use an identity check. ``== None`` on an array/tensor argument is
        # evaluated elementwise, and ``if`` on the resulting array raises
        # (ambiguous truth value) for arrays with more than one element.
        if inference_point is None:
            inference_point = tt([[0] * kernel.base_K.problem.d])
        self.inference_point = inference_point
        self.K = kernel.K

    def __call__(self, theta, pi, K_inv_y, varphi):
        """Return ``(theta, K(theta, K_inv_y, xs=inference_point, noise=False))``.

        ``pi`` and ``varphi`` are accepted only for signature compatibility
        with the other output hooks and are not used.
        """
        inference_value = self.K(theta, K_inv_y, xs=self.inference_point, noise=False)
        return theta, inference_value

class TimeProcessing:
    """Wrap an output hook so each result is prefixed with an elapsed time.

    Calling the wrapper returns ``(tt([[elapsed]]), *base(*args))``.
    ``elapsed`` is the number of seconds since this wrapper was constructed.
    """

    def __init__(self, base):
        self.base = base
        # Reference point for elapsed times. It is taken at construction, not
        # at the first call, so the first timestamp returned is generally not
        # zero; subtracting it only keeps the magnitudes small.
        # perf_counter is monotonic. Unlike time.time(), system clock
        # adjustments (e.g. NTP corrections) cannot make the timestamps jump
        # or go backwards.
        self.base_epoch = time.perf_counter()
        # One extra output: the timestamp prepended to the base hook's tuple.
        self.outputs = base.outputs + 1

    def __call__(self, *args):
        elapsed = time.perf_counter() - self.base_epoch
        return (tt([[elapsed]]), *self.base(*args))