import copy
import math
import time
import torch
import random
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from getdata import Dset
from scipy.stats import norm


###############################
###############################
###############################
###############################
###############################


def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch RNGs (plus every CUDA device
    when CUDA is available) with the same `seed`."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

###############################
###############################
###############################
###############################
###############################


def sample_correlated_gaussian(rho=0.5, dim=20, batch_size=128, to_cuda=False, cubic=False):
    """Draw `batch_size` x `dim` coordinate pairs from a correlated Gaussian.

    Every pair (x[i, j], y[i, j]) is an independent draw from a bivariate
    normal with zero mean, unit variances and correlation `rho`.

    Args:
        rho: correlation between paired coordinates.
        dim: number of coordinates per sample (columns of x and y).
        batch_size: number of samples (rows of x and y).
        to_cuda: if True, return float32 CUDA tensors instead of NumPy arrays.
        cubic: if True, replace y by y ** 3.

    Returns:
        (x, y), each of shape [batch_size, dim].
    """
    covariance = [[1.0, rho], [rho, 1.0]]
    pairs = np.random.multivariate_normal([0, 0], covariance, batch_size * dim)
    x_flat, y_flat = pairs.T

    x = x_flat.reshape(-1, dim)
    y = y_flat.reshape(-1, dim)
    if cubic:
        y = y ** 3

    if not to_cuda:
        return x, y
    return torch.from_numpy(x).float().cuda(), torch.from_numpy(y).float().cuda()


def rho_to_mi(rho, dim):
    """Closed-form mutual information (natural log) for `dim` independent
    coordinate pairs with correlation `rho`: -dim / 2 * log(1 - rho**2)."""
    half_dim = -dim / 2
    return half_dim * np.log(1 - rho ** 2)


def mi_to_rho(mi, dim):
    """Correlation that gives mutual information `mi` over `dim` coordinate
    pairs; the non-negative inverse of `rho_to_mi`."""
    exponent = -2 * mi / dim
    return np.sqrt(1 - np.exp(exponent))

###############################
###############################
###############################
###############################
###############################


class FF(nn.Module):
    """Feed-forward network: `num_layers` hidden blocks and a linear read-out.

    Each hidden block is [LayerNorm (if `layer_norm`)] -> Linear ->
    activation ('tanh' or 'relu') -> Dropout(`dropout_rate`). With
    `residual_connection`, each block's output is added to its input, which
    requires `dim_hidden == dim_input`. With no hidden blocks the read-out is
    a single Linear(dim_input, dim_output).
    """

    def __init__(self, dim_input, dim_hidden, dim_output, num_layers,
                 activation='relu', dropout_rate=0, layer_norm=False,
                 residual_connection=False):
        super(FF, self).__init__()
        assert (not residual_connection) or (dim_hidden == dim_input)
        self.residual_connection = residual_connection

        self.stack = nn.ModuleList()
        in_features = dim_input
        for _ in range(num_layers):
            modules = [nn.LayerNorm(in_features)] if layer_norm else []
            modules.append(nn.Linear(in_features, dim_hidden))
            modules.append({'tanh': nn.Tanh(), 'relu': nn.ReLU()}[activation])
            modules.append(nn.Dropout(dropout_rate))
            self.stack.append(nn.Sequential(*modules))
            # every block after the first consumes the hidden width
            in_features = dim_hidden

        self.out = nn.Linear(in_features, dim_output)

    def forward(self, x):
        for block in self.stack:
            h = block(x)
            x = x + h if self.residual_connection else h
        return self.out(x)


class MultiGaussKernelEE(nn.Module):
    """Learnable Gaussian-mixture density over R^d, used to estimate H(X).

    The mixture has K = `number_of_samples` modes in d = `hidden_size`
    dimensions. For mode k, the residual x - mean_k is multiplied by a
    lower-triangular matrix L_k. L_k has diagonal `diag[k]` and strictly-lower
    part `tril(tri[k], -1)`; the lower part is absent when
    `cov_off_diagonal='zero'`. The log-density of a mode is therefore
    logC - ||L_k (x - mean_k)||^2 / 2 + sum(log|diag_k|), i.e. L_k acts as a
    factor of the mode's precision matrix. Mixture log-weights are
    log_softmax(`weigh`).

    Args:
        device: device on which all tensors are created.
        number_of_samples: number of modes K.
        hidden_size: sample dimension d.
        average: 'weighted' (learnable mixture weights) or 'fixed' (uniform,
            not a Parameter).
        cov_diagonal: 'var' (one diagonal per mode, [1, K, d]) or 'const'
            (one diagonal shared by all modes, [1, 1, d]).
        cov_off_diagonal: 'var' (learnable [1, K, d, d] lower part) or 'zero'.
    """

    def __init__(self,
                 device,
                 number_of_samples,
                 hidden_size,
                 average='weighted',
                 cov_diagonal='var',
                 cov_off_diagonal='var',
                 ):
        super(MultiGaussKernelEE, self).__init__()
        self.K, self.d = number_of_samples, hidden_size
        self.device = device

        # log normalising constant of a d-dimensional standard normal
        self.logC = torch.tensor([-self.d / 2 * np.log(2 * np.pi)]).to(
            self.device)

        self.means = nn.Parameter(torch.rand(number_of_samples, hidden_size).to(
            self.device), requires_grad=True)

        if cov_diagonal == 'const':
            diag = torch.ones((1, 1, self.d))
        elif cov_diagonal == 'var':
            diag = torch.ones((1, self.K, self.d))
        else:
            assert False, f'Invalid cov_diagonal: {cov_diagonal}'
        self.diag = nn.Parameter(diag.to(self.device))

        if cov_off_diagonal == 'var':
            tri = torch.zeros((1, self.K, self.d, self.d))
            self.tri = nn.Parameter(tri.to(self.device))
        elif cov_off_diagonal == 'zero':
            self.tri = None
        else:
            assert False, f'Invalid cov_off_diagonal: {cov_off_diagonal}'

        self.weigh = torch.ones(
            (1, self.K), requires_grad=False).to(self.device)
        if average == 'weighted':
            self.weigh = nn.Parameter(self.weigh, requires_grad=True)
        else:
            assert average == 'fixed', f"Invalid average: {average}"

    def logpdf(self, x, y=None):
        """Mixture log-density at each row of `x` ([N, d]); returns shape [N].

        `y` is unused.
        """
        assert len(
            x.shape) == 2 and x.shape[1] == self.d, 'x has to have shape [N, d]'
        x = x[:, None, :]                         # [N, 1, d]
        w = torch.log_softmax(self.weigh, dim=1)  # [1, K] log mixture weights
        y = x - self.means                        # [N, K, d] residuals
        if self.tri is not None:
            # L_k @ residual, with L_k = diag(diag_k) + strictly-lower(tri_k)
            y = y * self.diag + \
                torch.squeeze(torch.matmul(torch.tril(
                    self.tri, diagonal=-1), y[:, :, :, None]), 3)
        else:
            y = y * self.diag
        y = torch.sum(y ** 2, dim=2)              # [N, K]

        # log|det L_k| = sum of log|diagonal| because L_k is triangular
        y = -y / 2 + torch.sum(torch.log(torch.abs(self.diag)), dim=2) + w

        y = torch.logsumexp(y, dim=-1)            # [N]

        return self.logC + y

    def learning_loss(self, x_samples, y=None):
        return -self.forward(x_samples)

    def update_parameters(self, kernel_dict):
        """Replace this kernel's modes by the union of the modes of all
        kernels in `kernel_dict` (its values are MultiGaussKernelEE).

        The merged mixture has sum_i K_i modes; `self.K` is updated to match.
        A shared ('const') diagonal is replicated to one row per mode so the
        concatenated shapes line up. If no kernel has a lower-triangular part,
        `tri` stays None. If only some kernels have one, the others get zeros,
        i.e. they keep their diagonal-only covariance.

        Raises:
            ValueError: if `kernel_dict` is empty.
        """
        kernels = list(kernel_dict.values())
        if not kernels:
            raise ValueError('kernel_dict must contain at least one kernel')
        has_tri = any(kernel.tri is not None for kernel in kernels)

        tri, means, weigh, diag = [], [], [], []
        for kernel in kernels:  # detach and clone
            n_modes = kernel.means.shape[0]
            means.append(kernel.means.detach().clone())
            weigh.append(kernel.weigh.detach().clone())
            # [1, 1, d] (shared) or [1, K_i, d] -> [1, K_i, d]
            diag.append(kernel.diag.detach().expand(
                1, n_modes, kernel.d).clone())
            if has_tri:
                if kernel.tri is not None:
                    tri.append(kernel.tri.detach().clone())
                else:
                    tri.append(torch.zeros((1, n_modes, kernel.d, kernel.d),
                                           device=kernel.means.device))

        if has_tri:
            self.tri = nn.Parameter(torch.cat(tri, dim=1).to(self.device),
                                    requires_grad=True)
        else:
            self.tri = None
        self.means = nn.Parameter(
            torch.cat(means, dim=0).to(self.device), requires_grad=True)
        self.weigh = nn.Parameter(
            torch.cat(weigh, dim=-1).to(self.device), requires_grad=True)
        self.diag = nn.Parameter(
            torch.cat(diag, dim=1).to(self.device), requires_grad=True)
        self.K = self.means.shape[0]

    def pdf(self, x):
        return torch.exp(self.logpdf(x))

    def forward(self, x, y=None):
        """Batch mean of |log-density| (the negative log-likelihood, made
        non-negative by abs), used as the entropy estimate."""
        # NOTE(review): abs() differs from the plain NLL wherever the
        # log-density is positive; kept to preserve existing behavior.
        y = torch.abs(-self.logpdf(x))
        return torch.mean(y)


class MultiGaussKernelCondEE(nn.Module):
    """Conditional Gaussian-mixture density p(x | y), used to estimate H(X|Y).

    Three FF networks map the conditioning input y (width `y_size`) to K
    per-mode quantities:
      * `weight`: mixture logits, turned into log-weights by log_softmax;
      * `std`: log-scales; their exp multiplies the residual in the density,
        so it acts as an inverse standard deviation of an isotropic mode;
      * `mean_weight`: the K means in R^{x_size}.
    """

    def __init__(self, device,
                 number_of_samples,  # number of mixture modes K
                 x_size, y_size,
                 layers=1,
                 ):
        super(MultiGaussKernelCondEE, self).__init__()
        self.K, self.d = number_of_samples, y_size
        self.device = device

        # log normalising constant of an x_size-dimensional standard normal
        self.logC = torch.tensor(
            [-x_size / 2 * np.log(2 * np.pi)]).to(self.device)

        hidden = self.d * 2
        self.std = FF(self.d, hidden, self.K, layers)
        self.weight = FF(self.d, hidden, self.K, layers)
        self.mean_weight = FF(self.d, hidden, self.K * x_size, layers)
        self.x_size = x_size

    def _get_mean(self, y):
        """Per-mode means of shape [N, K, x_size]."""
        return self.mean_weight(y).reshape((-1, self.K, self.x_size))

    def _get_std(self, y):
        """exp of the `std` network output, shape [N, K] (an inverse std)."""
        return self.std(y).exp()

    def get_weight(self, y):
        """Log mixture weights, log-softmax over the last (mode) axis."""
        return torch.log_softmax(self.weight(y), dim=-1)

    def logpdf(self, x, y):  # log p(x | y), shape [N]
        x_expanded = x[:, None, :]                 # [N, 1, x_size]

        log_w = self.get_weight(y)                 # [N, K]
        inv_std = self._get_std(y)                 # [N, K]
        means = self._get_mean(y)                  # [N, K, x_size]

        residual = x_expanded - means              # [N, K, x_size]
        scaled_sq = inv_std ** 2 * torch.sum(residual ** 2, dim=2)  # [N, K]

        log_terms = (-scaled_sq / 2
                     + self.x_size * torch.log(torch.abs(inv_std))
                     + log_w)
        return self.logC + torch.logsumexp(log_terms, dim=-1)

    def pdf(self, x, y):
        return torch.exp(self.logpdf(x, y))  # P(x|y)

    def forward(self, x, y):
        """Batch-mean negative log-likelihood -log p(x | y)."""
        return torch.mean(-self.logpdf(x, y))


class MIKernelEstimator(nn.Module):
    """Kernel mutual-information estimator built from two density models.

    `kernel_1` (MultiGaussKernelEE) scores the marginal of X and stands for
    H(X); `kernel_conditional` (MultiGaussKernelCondEE) scores X given Y and
    stands for H(X|Y). `forward` returns |H(X) - H(X|Y)|, while
    `learning_loss` returns H(X) + H(X|Y), which trains both models together.
    """

    def __init__(self, device, number_of_samples, x_size, y_size):
        super(MIKernelEstimator, self).__init__()
        self.count = 0           # number of forward() calls
        self.count_learning = 0  # number of learning_loss() calls

        # density model for H(X)
        self.kernel_1 = MultiGaussKernelEE(device, number_of_samples, x_size)

        # conditional density model for H(X|Y)
        self.kernel_conditional = MultiGaussKernelCondEE(
            device, number_of_samples, x_size, y_size)

    def get_kernel_conditional_pdf(self, x_samples, y_samples):
        return self.kernel_conditional.pdf(x_samples, y_samples)

    def _entropies(self, x_samples, y_samples):
        """Return the (H(X), H(X|Y)) estimates for this batch."""
        h_marginal = self.kernel_1(x_samples)
        h_conditional = self.kernel_conditional(x_samples, y_samples)
        return h_marginal, h_conditional

    def forward(self, x_samples, y_samples):
        h_marginal, h_conditional = self._entropies(x_samples, y_samples)
        self.count += 1
        # I(X;Y) = H(X) - H(X|Y), reported as an absolute value
        return torch.abs(h_marginal - h_conditional)

    def learning_loss(self, x_samples, y_samples):
        h_marginal, h_conditional = self._entropies(x_samples, y_samples)
        self.count_learning += 1
        return h_marginal + h_conditional


def train_KNIFE(trainDS, valDS, device, sim_conf, get_best):
    """Train a KNIFE model on `trainDS`, validating on `valDS` after each epoch.

    The conditional density is fitted over the loader's second item (`yi`,
    modelled with x_size=1) given its first item (`xi`). The conditioning
    width is taken from `trainDS.dset.shape[1]`, presumably the width of `xi`
    (TODO confirm).

    Args:
        trainDS: training dataset yielding (xi, yi) pairs; must expose `dset`.
        valDS: validation dataset yielding (X, y) pairs.
        device: torch device (cuda/cpu) for the model and the batches.
        sim_conf: dict with keys 'nepochs', 'modes_number', 'l_rate', 'wd'
            and 'batch_size'.
        get_best: if True, the returned model holds the parameters of the
            epoch with the lowest average validation loss.

    Returns:
        (model, train_loss_out, test_loss_out, best_ind):
        - train_loss_out and test_loss_out are numpy arrays of per-epoch
          average losses, of length 'nepochs';
        - best_ind is the selected epoch, or 0 when get_best is False.

    Raises:
        Exception: if the validation loss becomes NaN.
    """

    epochs = sim_conf['nepochs']
    modes_number = sim_conf['modes_number']

    # define model (x_size=1: the modelled variable yi is one-dimensional)
    model = MIKernelEstimator(device, modes_number,
                              1, trainDS.dset.shape[1]).to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=sim_conf['l_rate'],
                                 weight_decay=sim_conf['wd'])

    loader = torch.utils.data.DataLoader(trainDS, shuffle=True,
                                         batch_size=sim_conf['batch_size'],
                                         pin_memory=True, num_workers=0)

    test_loader = torch.utils.data.DataLoader(valDS,
                                              batch_size=sim_conf['batch_size'],
                                              shuffle=False,
                                              pin_memory=True, num_workers=0)

    pbar = tqdm(range(epochs))
    train_loss_out = np.zeros(epochs)
    test_loss_out = np.zeros(epochs)

    best_ind = 0
    best_loss = None
    best_pars = None

    for cnt in pbar:
        train_loss = 0
        n_train_batches = 0
        model.train()
        for xi, yi in loader:
            # single-sample batches are skipped
            if yi.shape[0] <= 1:
                continue
            xi = xi.to(device)
            yi = yi.to(device)
            optimizer.zero_grad()

            model_loss = model.learning_loss(yi, xi)
            train_loss += model_loss.item()

            model_loss.backward()
            optimizer.step()
            n_train_batches += 1

        test_loss = 0
        with torch.no_grad():
            model.eval()
            for (X, y) in test_loader:
                X = X.to(device)
                y = y.to(device)

                lossval = model.learning_loss(y, X)
                test_loss += lossval.item()

        # Average over the batches actually used for training. The validation
        # loss is normalised by the validation loader's length, not the
        # training loader's.
        avg_train_loss = train_loss / max(n_train_batches, 1)
        avg_test_loss = test_loss / len(test_loader)

        pbar.set_description('Training KNIFE: ' +
                             'Avg Train Loss: {0:.4f}, '.format(avg_train_loss) +
                             'Avg Val Loss: {0:.4f}'.format(avg_test_loss))

        train_loss_out[cnt] = avg_train_loss
        test_loss_out[cnt] = avg_test_loss

        if np.isnan(test_loss):
            plt.figure()
            plt.plot(train_loss_out)
            plt.plot(test_loss_out)
            raise Exception('Problem')

        if get_best and (best_pars is None or test_loss_out[cnt] < best_loss):
            best_ind = cnt
            best_loss = test_loss_out[cnt]
            best_pars = copy.deepcopy(model.state_dict())

    if get_best:
        # with zero epochs there is nothing to restore
        if best_pars is not None:
            model.load_state_dict(best_pars)
        return model, train_loss_out, test_loss_out, best_ind
    return model, train_loss_out, test_loss_out, 0


def restore_KNIFEnet(device, K_res, K_path, in_dim):
    """Rebuild a KNIFE model and load the weights stored at `K_path`.

    The mode count comes from K_res['conf']['modes_number'], x_size is 1 and
    the conditioning width is `in_dim`. This must match the architecture the
    checkpoint was saved from.
    """
    n_modes = K_res['conf']['modes_number']
    net = MIKernelEstimator(device, n_modes, 1, in_dim).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted paths.
    weights = torch.load(K_path, map_location=device)
    net.load_state_dict(weights)
    return net


def mixCDF(vals, mean, std, probs):
    """CDF of a Gaussian mixture: sum_k probs_k * Phi((vals - mean_k) / std_k).

    All arguments are broadcast together by NumPy, and mixture components run
    along the last axis of `mean`:
      * 2-D `mean` ([N, K]): one mixture per row; sums over axis 1 and
        returns shape [N] (for vals broadcasting to [N, K]).
      * more than 2 dims: one mixture per leading index; sums over the last
        axis.
      * 1-D ([K]) or 0-D `mean`: a single mixture; everything is summed to a
        scalar.
    `mean` may be any array-like (lists are accepted).
    """
    weighted = norm.cdf(vals, loc=mean, scale=std) * probs
    n_dims = np.ndim(mean)
    if n_dims == 2:
        return np.sum(weighted, axis=1)
    if n_dims > 2:
        return np.sum(weighted, axis=-1)
    return np.sum(weighted)


def computePg_KNIFEold(model, test_x, fx, eps, notion, sim_type, musig=0, sigmay=0):
    """Legacy version of `computePg_KNIFE`: probability of being epsilon-good.

    The conditional mixture of `model` at `test_x` serves as the predictive
    distribution, and its CDF is evaluated with `mixCDF`. The mixture has
    means `_get_mean`, standard deviations 1 / `_get_std` and weights
    softmax(`get_weight`).

    Args:
        model: KNIFE model exposing `kernel_conditional`.
        test_x: input samples to the regressor (tensor).
        fx: values of the regressor f(x) = y.
        eps: tolerance; its meaning depends on (sim_type, notion):
            * 'outY'/'abs': half-width around fx, in units of y normalized
              by its variance; Pg = P(|Y - fx| <= eps).
            * 'outY'/'rel': relative half-width; Delta = eps * |fx + musig|.
            * 'absE'/'abs': threshold on the error, already transformed by
              the error scaler; Pg = CDF(eps).
            * 'absE'/'rel': one threshold per sample (array);
              Pg[i] = CDF_i(eps[i]).
        notion: 'abs' or 'rel'.
        sim_type: 'outY' or 'absE'.
        musig: offset added to fx for the 'outY'/'rel' tolerance.
        sigmay: unused; kept for interface compatibility.

    Returns:
        Pg, the probability of being epsilon-good.

    Raises:
        ValueError: if `notion` or `sim_type` is not one of the values above.
    """
    with torch.no_grad():
        inputs = test_x.float()
        kernel = model.kernel_conditional
        mean = kernel._get_mean(inputs).squeeze().cpu().numpy()
        std = 1/kernel._get_std(inputs).cpu().numpy()
        weight = kernel.get_weight(inputs)
        probs = torch.softmax(weight, dim=1).cpu().numpy()

    if sim_type == 'outY':
        if notion == 'abs':
            Delta = eps
        elif notion == 'rel':
            Delta = eps*abs(fx + musig)
        else:
            raise ValueError(f'Invalid notion: {notion}')
        return mixCDF(fx+Delta, mean, std, probs) \
            - mixCDF(fx-Delta, mean, std, probs)

    if sim_type == 'absE':
        if notion == 'abs':
            return mixCDF(eps, mean, std, probs)
        if notion == 'rel':  # one epsilon value per sample
            Pg = np.zeros(eps.size)
            for cnt, val in enumerate(eps):
                Pg[cnt] = mixCDF(val, mean[cnt], std[cnt], probs[cnt])
            return Pg
        raise ValueError(f'Invalid notion: {notion}')

    raise ValueError(f'Invalid sim_type: {sim_type}')


def computePg_KNIFE(model, test_x, fx, eps, notion, sim_type, device, Pg_conf):
    """Probability of being epsilon-good under the KNIFE conditional mixture.

    The conditional mixture of `model` at `test_x` serves as the predictive
    distribution, and its CDF is evaluated with `mixCDF`. The mixture has
    means `_get_mean`, standard deviations 1 / `_get_std` and weights
    softmax(`get_weight`).

    Args:
        model: KNIFE model exposing `kernel_conditional`.
        test_x: inputs to the regressor (tensor).
        fx: regressor outputs f(x).
        eps: tolerance; its meaning depends on (sim_type, notion):
            * 'outY'/'abs': Pg = P(|Y - fx| <= eps).
            * 'outY'/'rel': Delta = eps * |fx + musig|; Pg = P(|Y - fx| <= Delta).
            * 'absE'/'abs': Pg = CDF(scaler_err(eps)).
            * 'absE'/'rel': Pg = CDF(scaler_err(|fx + musig| * eps)).
            * 'absErel'/'rel': Pg = CDF(scaler_err(eps)).
        notion: 'abs' or 'rel'.
        sim_type: 'outY', 'absE' or 'absErel'.
        device: unused; kept for interface compatibility.
        Pg_conf: dict of extra settings. 'musig' is the offset added to fx
            for relative tolerances. 'scaler_err' is an object whose
            `.transform` maps raw errors to the model's scale, presumably a
            fitted scaler.

    Returns:
        Pg, the probability of being epsilon-good.

    Raises:
        Exception: if a required Pg_conf key is missing, or for
            sim_type='absErel' with notion='abs'.
        ValueError: if `notion` or `sim_type` is not recognised.
    """
    with torch.no_grad():
        inputs = test_x.float()
        kernel = model.kernel_conditional
        mean = kernel._get_mean(inputs).squeeze().cpu().numpy()
        std = 1/kernel._get_std(inputs).cpu().numpy()
        weight = kernel.get_weight(inputs)
        probs = torch.softmax(weight, dim=1).cpu().numpy()

    if sim_type == 'outY':
        if notion == 'abs':
            Delta = eps
        elif notion == 'rel':
            if 'musig' not in Pg_conf:
                raise Exception(
                    'For outY/rel eps combination, musig has to be passed in Pg_Conf')
            Delta = eps*abs(fx + Pg_conf['musig'])
        else:
            raise ValueError(f'Invalid notion: {notion}')
        return mixCDF(fx+Delta, mean, std, probs) \
            - mixCDF(fx-Delta, mean, std, probs)

    if sim_type == 'absE':
        if 'scaler_err' not in Pg_conf:
            raise Exception(
                'The absE notion requires the scaler of the error dataset in Pg_conf')

        if notion == 'abs':
            eps_sc = Pg_conf['scaler_err'].transform(np.array([eps], ndmin=2))[
                0][0]  # scaled epsilon
        elif notion == 'rel':  # one epsilon value per sample
            if 'musig' not in Pg_conf:
                raise Exception(
                    'The relative notion requires musig to be passed in Pg_conf')
            eps_sc = Pg_conf['scaler_err'].transform(
                np.abs(fx+Pg_conf['musig'])*eps)
        else:
            raise ValueError(f'Invalid notion: {notion}')
        return mixCDF(eps_sc, mean, std, probs)

    if sim_type == 'absErel':
        if 'scaler_err' not in Pg_conf:
            raise Exception(
                'The absErel notion requires the scaler of the error dataset in Pg_conf')

        if notion == 'abs':
            raise Exception(
                'The absErel statistic cannot compute the abs error')
        if notion != 'rel':
            raise ValueError(f'Invalid notion: {notion}')
        eps_sc = Pg_conf['scaler_err'].transform(
            np.array([eps], ndmin=2))[0][0]
        return mixCDF(eps_sc, mean, std, probs)

    raise ValueError(f'Invalid sim_type: {sim_type}')


def sampleKNIFE(model, test_x, nsamps, **kwargs):
    """Draw `nsamps` samples from the KNIFE conditional mixture at each input.

    For each row of `test_x`, a mixture mode is chosen with probabilities
    softmax(get_weight). A normal sample is then drawn from that mode, with
    mean from `_get_mean` and standard deviation 1 / `_get_std`.

    Args:
        model: object exposing `kernel_conditional` with `_get_mean`,
            `_get_std` and `get_weight` (e.g. an MIKernelEstimator).
        test_x: conditioning inputs, a batch [N, d] or a single input [d].
        nsamps: number of samples per input.
        requires_grad (keyword, optional): if truthy, the mixture parameters
            are computed under torch.enable_grad(); otherwise under
            torch.no_grad(). When omitted, the caller's current grad mode is
            used.

    Returns:
        Tensor of samples, [N, nsamps] (or [1, nsamps] for a 1-D `test_x`).

    Raises:
        Exception: if the mixture weights are not 1-D or 2-D.
    """
    import contextlib

    if 'requires_grad' not in kwargs:
        grad_ctx = contextlib.nullcontext()
    elif kwargs['requires_grad']:
        grad_ctx = torch.enable_grad()
    else:
        grad_ctx = torch.no_grad()

    with grad_ctx:
        inputs = test_x.float()
        kernel = model.kernel_conditional
        mean = kernel._get_mean(inputs).squeeze(2)
        std = 1/kernel._get_std(inputs)
        weight = kernel.get_weight(inputs)
        if weight.ndim not in (1, 2):
            raise Exception('Expected 1D or 2D tensor of weights')
        probs = torch.softmax(weight, dim=weight.ndim - 1)

    # index of the gaussian chosen for each (input, sample)
    norm_idx = torch.multinomial(probs, nsamps, replacement=True)

    if test_x.ndim == 1:
        # single input: add the batch axis so indexing matches mean's [1, K]
        std = std.unsqueeze(0)
        norm_idx = norm_idx.unsqueeze(0)

    # mu_choice[i, j] = mean[i, norm_idx[i, j]], same for std
    mu_choice = torch.gather(mean, 1, norm_idx)
    std_choice = torch.gather(std, 1, norm_idx)

    return torch.normal(mu_choice, std_choice)


###############################
###############################
###############################
###############################
###############################


def main():
    """Demo: train the kernel MI estimator on synthetic correlated Gaussians.

    Each training step does the following:
      * draws a fresh batch (x, y) from `sample_correlated_gaussian`, using
        the correlation that `mi_to_rho` maps to `mi_value` over
        `sample_dim` coordinates;
      * records the current MI estimate;
      * takes one Adam step on `learning_loss`;
      * stores the conditional kernel's outputs.
    Afterwards it plots the loss curve, evaluates the conditional kernel on a
    large batch of y, and draws one mixture mode per sample.
    """
    set_seed(1)

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    sample_dim = 20  # dimension of x and y
    batch_size = 256  # only relevant for synthetic data generation
    # the number of mixture modes strongly influences the estimate
    modes_number = batch_size
    learning_rate = 1e-3
    # target MI, used only to generate the synthetic data
    # (other values tried: 4.0, 6.0, 8.0, 10.0)
    mi_value = 2.0
    training_steps = 5  # e.g. 5000 for a full run
    cubic = True  # if True, y is cubed (synthetic data generation only)

    model = MIKernelEstimator(device, modes_number,
                              sample_dim, sample_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)

    progress_bar = tqdm(range(training_steps),
                        'Training Loop', position=0, leave=True)
    rho = mi_to_rho(mi_value, sample_dim)
    mi_est_values = []
    model_loss_history = []
    mean_history = []
    std_history = []
    cond_pdf_history = []
    weight_history = []

    for step in progress_bar:
        batch_x, batch_y = sample_correlated_gaussian(rho, dim=sample_dim, batch_size=batch_size,
                                                      to_cuda=use_cuda, cubic=cubic)
        # as_tensor accepts both NumPy arrays and existing tensors (CUDA path)
        # without the copy warning torch.tensor(tensor) emits
        batch_x = torch.as_tensor(batch_x).float().to(device)
        batch_y = torch.as_tensor(batch_y).float().to(device)

        # monitoring only: no autograd graph needed
        model.eval()
        with torch.no_grad():
            mi_est_values.append(model(batch_x, batch_y).item())

        model.train()
        model_loss = model.learning_loss(batch_x, batch_y)
        model_loss_history.append(model_loss.item())

        optimizer.zero_grad()
        model_loss.backward()
        optimizer.step()

        # Post-step snapshots are taken under no_grad so the history lists do
        # not keep every step's autograd graph alive.
        with torch.no_grad():
            cond_pdf_history.append(model.get_kernel_conditional_pdf(
                x_samples=batch_x, y_samples=batch_y))
            mean_history.append(model.kernel_conditional._get_mean(batch_y))
            std_history.append(model.kernel_conditional._get_std(batch_y))
            weight_history.append(
                model.kernel_conditional.get_weight(batch_y))

        progress_bar.set_description(
            f"Training Loop: step: {step}, batch_x.shape: {batch_x.shape}, batch_y.shape: {batch_y.shape}, model_loss: {model_loss.item()}"
        )

    plt.figure()
    plt.plot(model_loss_history)

    # samples generation
    _, batch_y = sample_correlated_gaussian(rho, dim=sample_dim, batch_size=10000,
                                            to_cuda=use_cuda, cubic=cubic)
    batch_y = torch.as_tensor(batch_y).float().to(device)

    with torch.no_grad():
        mean = model.kernel_conditional._get_mean(batch_y)
        std = model.kernel_conditional._get_std(batch_y)
        weight = model.kernel_conditional.get_weight(batch_y)

    print(f"mean_shape: {mean.shape}")
    print(f"std_shape: {std.shape}")
    print(f"weight_shape: {weight.shape}")

    # one row of mode probabilities per sample (weight holds log-probabilities)
    weight_distribution = torch.exp(weight)
    print(f"weight_distribution_shape: {weight_distribution.shape}")

    idx_ndarray = np.arange(0, weight_distribution.shape[1])
    print(f"idx_ndarray_shape: {idx_ndarray.shape}")
    print(f"idx_ndarray: {idx_ndarray}")

    # draw one mixture mode per sample in a single vectorised call
    mode_choices = torch.multinomial(
        weight_distribution, num_samples=1).squeeze(1)
    print(f"mode_choices_shape: {mode_choices.shape}")


# Run the synthetic-data training demo only when executed as a script.
if __name__ == '__main__':
    main()
