# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import copy
import torch
import random
import argparse
import numpy as np
from tqdm import tqdm
from scipy.stats import norm
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset


def reset_seeds(seed):
    """Seed the global RNGs of ``random``, NumPy and PyTorch with ``seed``.

    The three generators are seeded in that order, each with the same value.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)


class NegativeLogLikelihoodLoss(torch.nn.Module):
    """
    Gaussian negative log-likelihood (up to an additive constant).

    Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles
    Equation (1)
    (https://arxiv.org/abs/1612.01474)
    """

    def __init__(self):
        super(NegativeLogLikelihoodLoss, self).__init__()

    def forward(self, yhat, y):
        """Return the batch-averaged loss.

        Args:
            yhat: 2-D tensor whose column 0 holds the predicted mean and
                column 1 the raw (unconstrained) variance output.
            y: targets; broadcast against the ``(N, 1)`` mean column.

        Returns:
            Scalar tensor: mean of ``log(var) / 2 + (y - mean)^2 / (2 * var)``.
        """
        mean = yhat[:, 0].view(-1, 1)
        raw_variance = yhat[:, 1].view(-1, 1)

        # make variance positive and stable (footnote 2).
        # softplus(x) == log(1 + exp(x)) but does not overflow to inf for
        # large x, which the explicit exp().add(1).log() chain does.
        variance = torch.nn.functional.softplus(raw_variance).add(0.001)

        log_term = variance.log().div(2)
        squared_error_term = (y - mean).pow(2).div(variance.mul(2))
        return (log_term + squared_error_term).mean()


class Perceptron(torch.nn.Module):
    """Fully connected network stored in ``self.perceptron``.

    With ``n_layers == 0`` the network is a single linear map from
    ``n_inputs`` to ``n_outputs``.  Otherwise it is a stack of
    ``Linear -> ReLU -> Dropout`` blocks of width ``n_hiddens`` followed by a
    final linear layer.  ``alpha`` is accepted but not used by this class.
    ``loss_function`` starts as ``None`` and must be assigned (e.g. by a
    subclass) before ``loss`` is called.
    """

    def __init__(self,
                 n_inputs,
                 n_outputs,
                 n_layers,
                 n_hiddens,
                 alpha,
                 dropout):
        super().__init__()

        if n_layers == 0:
            widths = [n_inputs, n_outputs]
        else:
            # One block is always created, plus one per extra requested layer.
            n_blocks = 1 + len(range(n_layers - 1))
            widths = [n_inputs] + [n_hiddens] * n_blocks + [n_outputs]

        modules = []
        # Every consecutive width pair except the last becomes a hidden block.
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            modules.append(torch.nn.Linear(fan_in, fan_out))
            modules.append(torch.nn.ReLU())
            modules.append(torch.nn.Dropout(dropout))
        # Output layer has no activation or dropout.
        modules.append(torch.nn.Linear(widths[-2], widths[-1]))

        self.perceptron = torch.nn.Sequential(*modules)
        self.loss_function = None

    def loss(self, x, y):
        """Apply ``self.loss_function`` to the network output for ``x`` and ``y``."""
        prediction = self.perceptron(x)
        return self.loss_function(prediction, y)


class Ensemble(torch.nn.Module):
    """Ensemble of ``n_ens`` independently initialised ``ConditionalGaussian`` learners.

    Each learner receives ``n_inputs`` inputs and produces ``n_outputs + 1``
    outputs (the extra column carries the variance output), with dropout
    disabled.
    """

    def __init__(self,
                 n_ens,
                 n_inputs,
                 n_outputs,
                 n_layers,
                 n_hiddens,
                 alpha):
        super().__init__()

        self.alpha = alpha

        # Shared constructor arguments for every ensemble member.
        member_kwargs = dict(n_inputs=n_inputs,
                             n_outputs=n_outputs + 1,
                             n_layers=n_layers,
                             n_hiddens=n_hiddens,
                             alpha=alpha,
                             dropout=0)

        self.learners = torch.nn.ModuleList()
        for _ in range(n_ens):
            self.learners.append(ConditionalGaussian(**member_kwargs))

    def predict(self, x, **kwargs):
        """Aggregate the predictions of all learners for the batch ``x``.

        Extra keyword arguments are forwarded to every learner's ``predict``.
        The per-learner results are gathered in buffers created by
        ``torch.zeros`` without a device argument.

        Returns:
            Tuple ``(mean, low, high, std, std2)``: the mean, ``std`` and
            ``std2`` entries are plain averages over learners; ``low`` and
            ``high`` are the averaged bounds shifted by
            ``norm.ppf(alpha / 2)`` times the across-learner standard
            deviation of the respective bound.

        NOTE(review): ``norm.ppf(alpha / 2)`` is negative for ``alpha < 1``,
        so ``low`` is shifted upwards and ``high`` downwards by the spread
        term - confirm this sign convention is what callers expect.
        """
        n_members = len(self.learners)
        n_points = x.size(0)

        means = torch.zeros(n_members, n_points, 1)
        lows = torch.zeros(n_members, n_points, 1)
        highs = torch.zeros(n_members, n_points, 1)
        stds = torch.zeros(n_members, n_points, 1)
        stds2 = torch.zeros(n_members, n_points, 1)

        for idx, member in enumerate(self.learners):
            mu, lo, hi, sd, sd2 = member.predict(x, **kwargs)
            means[idx] = mu
            lows[idx] = lo
            highs[idx] = hi
            stds[idx] = sd
            stds2[idx] = sd2

        threshold = norm.ppf(self.alpha / 2)
        # Unbiased spread needs at least two learners.
        unbiased = n_members > 1

        low_out = lows.mean(0) - threshold * lows.std(0, unbiased)
        high_out = highs.mean(0) + threshold * highs.std(0, unbiased)

        return means.mean(0), low_out, high_out, stds.mean(0), stds2.mean(0)

    def loss(self, x, y):
        """Return the sum of every learner's loss on ``(x, y)``."""
        return sum(member.loss(x, y) for member in self.learners)


class ConditionalGaussian(Perceptron):
    """Perceptron whose two output columns parameterise a Gaussian.

    Column 0 of the network output is the mean and column 1 a raw
    (unconstrained) variance output.  Training uses
    ``NegativeLogLikelihoodLoss``.
    """

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``Perceptron``; ``alpha`` is required."""
        super(ConditionalGaussian, self).__init__(**kwargs)
        self.loss_function = NegativeLogLikelihoodLoss()
        self.alpha = kwargs["alpha"]

    def predict(self, x, **kwargs):
        """Predict the Gaussian parameters for the batch ``x``.

        Args:
            x: input batch passed to the underlying network.
            **kwargs: only ``requires_grad`` is read.  When it is missing or
                falsy the network output is detached from the autograd graph.

        Returns:
            Tuple ``(mean, mean - interval, mean + interval, std, std2)`` of
            ``(N, 1)`` tensors, where ``interval`` is
            ``std2 * norm.ppf(alpha / 2)``, ``std2`` is the square root of the
            softplus-transformed variance and ``std`` is the square root of
            the raw variance output.

        NOTE(review): ``norm.ppf(alpha / 2)`` is negative for ``alpha < 1``,
        so ``mean - interval`` lies above the mean and ``mean + interval``
        below it - confirm callers expect this ordering.
        NOTE(review): ``std`` is NaN wherever the raw variance output is
        negative; ``std2`` is the always-defined alternative.
        """
        predictions = self.perceptron(x)
        if not kwargs.get('requires_grad', False):
            predictions = predictions.detach()

        mean = predictions[:, 0].view(-1, 1)
        var = predictions[:, 1].view(-1, 1)
        # softplus(x) == log(1 + exp(x)) without overflowing to inf for large x.
        var2 = torch.nn.functional.softplus(var).add(1e-6)
        interval = var2.sqrt().mul(norm.ppf(self.alpha / 2))
        std = var.sqrt()
        std2 = var2.sqrt()
        return mean, mean - interval, mean + interval, std, std2


def train_conditional_gaussian(trainDS, valDS, device, sim_conf, get_best):
    """Train an ``Ensemble`` of conditional-Gaussian networks with Adam.

    Args:
        trainDS: training dataset.  It must expose the tensors ``dset`` and
            ``out`` (their ``size(1)`` fixes the network input/output widths)
            and yield ``(x, y)`` pairs when wrapped in a ``DataLoader``.
        valDS: validation dataset yielding ``(x, y)`` pairs.
        device: device the model and every batch are moved to.
        sim_conf: mapping with the keys ``'nepochs'``, ``'n_ens'``,
            ``'alpha'``, ``'n_hidden_layers'``, ``'n_hidden_units'``,
            ``'l_rate'``, ``'wd'`` and ``'batch_size'``.
        get_best: when truthy, the weights of the epoch with the lowest
            validation loss are restored before returning.

    Returns:
        Tuple ``(model, train_loss_out, test_loss_out, best_ind)`` where the
        two loss arrays hold one per-batch average per epoch and ``best_ind``
        is the index of the restored epoch (``0`` when ``get_best`` is falsy
        or no epoch was run).
    """
    epochs = sim_conf['nepochs']
    batch_size = sim_conf['batch_size']

    # NOTE: inputs and targets are used as provided; no standardisation is
    # applied inside this function.
    loader = torch.utils.data.DataLoader(trainDS, shuffle=True,
                                         batch_size=batch_size,
                                         pin_memory=True, num_workers=0)

    test_loader = torch.utils.data.DataLoader(valDS,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              pin_memory=True, num_workers=0)

    network_name = "ConditionalGaussian"

    model = Ensemble(sim_conf['n_ens'],
                     trainDS.dset.size(1), trainDS.out.size(1),
                     sim_conf['n_hidden_layers'], sim_conf['n_hidden_units'],
                     sim_conf['alpha'])
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=sim_conf['l_rate'],
                                 weight_decay=sim_conf['wd'])

    train_loss_out = np.zeros(epochs)
    test_loss_out = np.zeros(epochs)

    # Best-epoch bookkeeping; initialised up front so that a run with zero
    # epochs still returns cleanly.
    best_ind = 0
    best_loss = None
    best_pars = None

    pbar = tqdm(range(epochs))
    for cnt in pbar:
        train_loss = 0
        model.train()
        for xi, yi in loader:
            # Single-sample batches are skipped (they still count in the
            # len(loader) denominator of the reported average).
            if yi.shape[0] <= 1:
                continue
            xi = xi.to(device)
            yi = yi.to(device)
            optimizer.zero_grad()

            model_loss = model.loss(xi, yi)
            train_loss += model_loss.item()

            model_loss.backward()
            optimizer.step()

        test_loss = 0
        with torch.no_grad():
            model.eval()
            for (X, y) in test_loader:
                X = X.to(device)
                y = y.to(device)
                test_loss += model.loss(X, y).item()

        avg_train_loss = train_loss / len(loader)
        # Validation loss is averaged over the validation batches, matching
        # the value shown in the progress bar.
        avg_val_loss = test_loss / len(test_loader)

        pbar.set_description(f'Training {network_name}: '
                             f'Avg Train Loss: {avg_train_loss:.4f}, '
                             f'Avg Val Loss: {avg_val_loss:.4f}')
        train_loss_out[cnt] = avg_train_loss
        test_loss_out[cnt] = avg_val_loss

        # The first epoch always becomes the initial best; later epochs only
        # replace it on a strict improvement.
        if get_best and (best_pars is None or best_loss > avg_val_loss):
            best_ind = cnt
            best_loss = avg_val_loss
            best_pars = copy.deepcopy(model.state_dict())

    if get_best and best_pars is not None:
        model.load_state_dict(best_pars)

    return model, train_loss_out, test_loss_out, best_ind


def mixCDF(vals, mean, std):
    """Evaluate the normal CDF with location ``mean`` and scale ``std`` at ``vals``.

    The arguments are handed to ``scipy.stats.norm.cdf`` unchanged.
    """
    gaussian_params = {"loc": mean, "scale": std}
    return norm.cdf(vals, **gaussian_params)

# def computePg_conditional_gaussian(model, test_x, fx, eps, notion, sim_type, musig):
#     with torch.no_grad():
#         mean, _, _, std, std2 = model.predict(test_x)

#     if sim_type == 'outY':
#         if notion == 'abs':
#             Pg = mixCDF(fx+eps, mean, std) \
#                 - mixCDF(fx-eps, mean, std)
#         elif notion == 'rel':
#             Delta = eps*abs(fx + musig)
#             Pg = mixCDF(fx+Delta, mean, std) \
#                 - mixCDF(fx-Delta, mean, std)
#             # plt.plot(fx+Delta, 'k')
#             # plt.plot(fx-Delta, 'k')
#     elif sim_type == 'absE':
#         raise Exception('To be implemented')

#     return Pg


def computePg_conditional_gaussian(model, test_x, fx, eps, notion, sim_type, device, Pg_conf):
    """Compute a probability ``Pg`` from the model's predictive Gaussian.

    The model's predicted mean and ``std2`` at ``test_x`` define a normal
    distribution whose CDF is evaluated as follows:

    * ``sim_type == 'outY'``: probability mass in ``[fx - d, fx + d]`` with
      ``d = eps`` (``notion == 'abs'``) or
      ``d = eps * abs(fx + Pg_conf['musig'])`` (``notion == 'rel'``).
    * ``sim_type == 'absE'``: CDF at the scaled threshold, where the
      threshold is ``eps`` (``'abs'``) or ``abs(fx + musig) * eps``
      (``'rel'``), transformed by ``Pg_conf['scaler_err']``.
    * ``sim_type == 'absErel'``: CDF at scaled ``eps``; only
      ``notion == 'rel'`` is supported.

    Args:
        model: object whose ``predict(x)`` returns a 5-tuple with the mean
            first and ``std2`` last.
        test_x: tensor of inputs; converted with ``.float()`` before use.
        fx: reference values the interval / threshold is built around.
        eps: tolerance.
        notion: ``'abs'`` or ``'rel'``.
        sim_type: ``'outY'``, ``'absE'`` or ``'absErel'``.
        device: unused; kept for interface compatibility.
        Pg_conf: mapping providing ``'musig'`` and/or ``'scaler_err'`` when
            the chosen combination needs them.

    Returns:
        Whatever ``mixCDF`` returns for the selected combination.

    Raises:
        ValueError: for an unknown ``sim_type`` / ``notion``, an unsupported
            combination, or a missing ``Pg_conf`` entry.
    """
    # Validate the configuration before running the model so that a bad call
    # fails with a clear message instead of an unbound-variable error.
    if sim_type not in ('outY', 'absE', 'absErel'):
        raise ValueError(
            f"Unknown sim_type {sim_type!r}; expected 'outY', 'absE' or 'absErel'")
    if notion not in ('abs', 'rel'):
        raise ValueError(
            f"Unknown notion {notion!r}; expected 'abs' or 'rel'")

    if sim_type in ('absE', 'absErel') and 'scaler_err' not in Pg_conf:
        raise ValueError(
            f'The {sim_type} notion requires the scaler of the error dataset in Pg_conf')
    if sim_type == 'absErel' and notion == 'abs':
        raise ValueError(
            'The absErel statistic cannot compute the abs error')
    if notion == 'rel' and 'musig' not in Pg_conf:
        if sim_type == 'outY':
            raise ValueError(
                'For outY/rel eps combination, musig has to be passed in Pg_Conf')
        if sim_type == 'absE':
            raise ValueError(
                'The relative notion requires musig to be passed in Pg_conf')

    with torch.no_grad():
        mean, _, _, _, std2 = model.predict(test_x.float())

    if sim_type == 'outY':
        if notion == 'abs':
            half_width = eps
        else:
            half_width = eps * abs(fx + Pg_conf['musig'])
        return mixCDF(fx + half_width, mean, std2) \
            - mixCDF(fx - half_width, mean, std2)

    scaler = Pg_conf['scaler_err']
    if sim_type == 'absE' and notion == 'rel':
        # The threshold depends on fx, so it is scaled as given.
        eps_sc = scaler.transform(np.abs(fx + Pg_conf['musig']) * eps)
    else:
        # Scalar threshold: wrap as a 1x1 array for the scaler, then unwrap.
        eps_sc = scaler.transform(np.array([eps], ndmin=2))[0][0]

    return mixCDF(eps_sc, mean, std2)


def restore_cond_gaussian(device, C_res, C_path, in_dim):
    """Rebuild a conditional-Gaussian ``Ensemble`` and load saved weights into it.

    The architecture is read from ``C_res['conf']`` (keys ``'n_ens'``,
    ``'n_hidden_layers'``, ``'n_hidden_units'`` and ``'alpha'``); the input
    width is ``in_dim`` and the output width is fixed to 1.  The state dict
    is read from ``C_path`` with ``map_location=device``.  The model is not
    explicitly moved to ``device`` here.
    """
    conf = C_res['conf']
    architecture = (conf['n_ens'],
                    in_dim, 1,
                    conf['n_hidden_layers'],
                    conf['n_hidden_units'],
                    conf['alpha'])
    ensemble = Ensemble(*architecture)

    saved_state = torch.load(C_path, map_location=device)
    ensemble.load_state_dict(saved_state)
    return ensemble


def samplecond_gaussian(model, test_x, nsamps, **kwargs):
    """Draw ``nsamps`` normal samples per input from the model's prediction.

    Args:
        model: object whose ``predict(x, requires_grad=...)`` returns a
            5-tuple with the mean first and ``std2`` last.
        test_x: input samples where to compute the normal samples.  A 1-D
            tensor is treated as a single input row.
        nsamps: number of samples to draw for every input row.
        **kwargs: only ``requires_grad`` is read.  When truthy the prediction
            runs with gradients enabled; when falsy or omitted it runs under
            ``torch.no_grad()``.

    Returns:
        Tensor of shape ``(n_rows, nsamps)`` created on ``test_x.device``.

    NOTE(review): the samples come from ``torch.normal(mean, std2)``; confirm
    that this gives the gradient behaviour callers expect when
    ``requires_grad`` is set.
    """
    # Default to no gradient tracking when the flag is not supplied.
    requires_grad = kwargs.get('requires_grad', False)
    grad_mode = torch.enable_grad() if requires_grad else torch.no_grad()

    # Promote a single 1-D input to a batch of one row.
    batch = test_x.unsqueeze(0) if test_x.ndim == 1 else test_x

    with grad_mode:
        mean, _, _, _, std2 = model.predict(
            batch.float(), requires_grad=requires_grad)

    out = torch.zeros(batch.shape[0], nsamps, device=test_x.device)
    for cnt in range(nsamps):
        out[:, cnt] = torch.normal(mean, std2).flatten()

    return out


if __name__ == "__main__":
    # train_conditional_gaussian needs its datasets, device and configuration,
    # so calling it without arguments can only fail with a TypeError.  Exit
    # with an explanatory message instead.
    raise SystemExit(
        "This module has no standalone entry point: import it and call "
        "train_conditional_gaussian(trainDS, valDS, device, sim_conf, get_best).")

    ########################################
    ########################################
    ########################################
    ########################################
    ########################################

    # parser = argparse.ArgumentParser()
    # parser.add_argument('--dataset', type=str, default="boston")
    # parser.add_argument('--seed', type=int, default=3)
    # parser.add_argument('--n_hidden_layers', type=int, default=1)
    # parser.add_argument('--n_epochs', type=int, default=2000)
    # parser.add_argument('--n_hidden_units', type=int, default=64)
    # parser.add_argument('--bs', type=int, default=64)
    # parser.add_argument('--lr', type=float, default=1e-3)
    # parser.add_argument('--wd', type=float, default=0)
    # parser.add_argument('--n_ens', type=int, default=5)
    # parser.add_argument('--alpha', type=float, default=0.05)
    # args = parser.parse_args()

    # reset_seeds(args.seed)

    # # load data
    # data = np.loadtxt("UCI_Datasets/{}.txt".format(args.dataset))
    # x_al = data[:, :-1]
    # y_al = data[:, -1].reshape(-1, 1)

    # x_tr, x_te, y_tr, y_te = train_test_split(
    #     x_al, y_al, test_size=0.1, random_state=args.seed)
    # x_tr, x_va, y_tr, y_va = train_test_split(
    #     x_tr, y_tr, test_size=0.2, random_state=args.seed)

    # s_tr_x = StandardScaler().fit(x_tr)
    # s_tr_y = StandardScaler().fit(y_tr)

    # x_tr = torch.Tensor(s_tr_x.transform(x_tr))
    # x_va = torch.Tensor(s_tr_x.transform(x_va))
    # x_te = torch.Tensor(s_tr_x.transform(x_te))

    # y_tr = torch.Tensor(s_tr_y.transform(y_tr))
    # y_va = torch.Tensor(s_tr_y.transform(y_va))
    # y_te = torch.Tensor(s_tr_y.transform(y_te))
    # y_al = torch.Tensor(s_tr_y.transform(y_al))

    # network_name = "ConditionalGaussian"

    # reset_seeds(args.seed)

    # network = Ensemble(args.n_ens,
    #                    x_tr.size(1), y_tr.size(1),
    #                    args.n_hidden_layers, args.n_hidden_units,
    #                    args.alpha)

    # loader_tr = DataLoader(TensorDataset(x_tr, y_tr),
    #                        shuffle=True,
    #                        batch_size=args.bs)

    # optimizer = torch.optim.Adam(network.parameters(),
    #                              lr=args.lr,
    #                              weight_decay=args.wd)

    # progress_bar = tqdm(range(args.n_epochs), desc=network_name)

    # for epoch in progress_bar:
    #     for (xi, yi) in loader_tr:
    #         optimizer.zero_grad()
    #         network.loss(xi, yi).backward()
    #         optimizer.step()
    #         progress_bar.set_description(
    #             f"{network_name} | {args.dataset} | {epoch} | {network.loss(xi, yi).item():.5f}")

    # # make predictions
    # p_mean_tr, p_low_tr, p_high_tr = network.predict(x_tr)
    # p_mean_va, p_low_va, p_high_va = network.predict(x_va)
    # p_mean_te, p_low_te, p_high_te = network.predict(x_te)

    # # final losses
    # mse_tr = network.loss(x_tr, y_tr)
    # mse_va = network.loss(x_va, y_va)
    # mse_te = network.loss(x_te, y_te)

    # # percentage of captured points
    # capture_tr = (p_low_tr.lt(y_tr) * y_tr.lt(p_high_tr)).float().mean()
    # capture_va = (p_low_va.lt(y_va) * y_va.lt(p_high_va)).float().mean()
    # capture_te = (p_low_te.lt(y_te) * y_te.lt(p_high_te)).float().mean()

    # # width of intervals
    # y_range = (y_al.max() - y_al.min())
    # width_tr = (p_high_tr - p_low_tr).abs().mean() / y_range
    # width_va = (p_high_va - p_low_va).abs().mean() / y_range
    # width_te = (p_high_te - p_low_te).abs().mean() / y_range

    # print("{:<22} | {:<26} | {:.5f} {:.5f} {:.5f} | {:.5f} {:.5f} {:.5f} | {:.5f} {:.5f} {:.5f} | {:<2} | {:<4} | {} | {}".format(
    #     network_name + "-" + str(args.n_ens), args.dataset,
    #     mse_tr, capture_tr, width_tr,
    #     mse_va, capture_va, width_va,
    #     mse_te, capture_te, width_te,
    #     args.seed,
    #     epoch,
    #     args.lr,
    #     args.wd))

    # print(f"network_name: {network_name} | n_ens: {args.n_ens} | dataset: {args.dataset} | mse_tr: {mse_tr} | capture_tr: {capture_tr} | width_tr: {width_tr} | mse_va: {mse_va} | capture_va: {capture_va} | width_va: {width_va} | mse_te: {mse_te} | capture_te: {capture_te} | width_te: {width_te} | seed: {args.seed} | epoch: {epoch} | lr: {args.lr} | wd: {args.wd}")
