# -*- coding: utf-8 -*-

import torch
from tqdm import tqdm
import numpy as np
from networks import DeepNetBN
import copy

class QuantileLoss(torch.nn.Module):
    """Pinball (quantile) loss averaged over every element of the batch.

    The module holds no parameters; the quantile level is supplied per call.
    """

    def forward(self, yhat, y, tau):
        '''Computes the pinball loss. It can be written as:
            Loss = (yhat - y) * (1{yhat - y >= 0} - tau)

        Parameters:
        yhat: predicted quantiles.
        y: targets, broadcastable against ``yhat``.
        tau: quantile level(s), a float or a tensor broadcastable against
            ``yhat - y``.

        Returns the mean of the element-wise loss.
        '''
        residual = yhat - y
        indicator = (residual >= 0).float()
        # The weight is piecewise constant in yhat, so it is excluded from
        # the graph; the gradient flows only through the residual.
        weight = (indicator - tau).detach()
        return (weight * residual).mean()
        
    
def augment(x, tau=None, device=None):
    """Append a quantile-level column to the input batch ``x``.

    If ``tau`` is None the median level 0.5 is used for every row. If ``tau``
    is a real scalar (Python int/float or a numpy integer/floating scalar),
    the same level is used for every row. Otherwise ``tau`` is expected to be
    a tensor of shape (x.size(0), 1) holding one level per row.

    Internally the level is rescaled from (0, 1) to (-6, 6) via
    ``(tau - 0.5) * 12`` before being concatenated as the last column.

    Parameters:
    x: input tensor; rows are samples, concatenated along dim 1.
    tau: None, a real scalar, or a per-row tensor of quantile levels.
    device: target device of the result. If None, the result stays on
        ``x.device``.

    Returns a tensor with one more column than ``x``.
    """
    if tau is None:
        tau = 0.5
    if isinstance(tau, (int, float, np.integer, np.floating)):
        # Build the constant column directly on x's device so the
        # concatenation below never mixes devices.
        tau = torch.full((x.size(0), 1), float(tau), device=x.device)

    out = torch.cat((x, (tau - 0.5) * 12), 1)
    return out if device is None else out.to(device)

class IntervalLoss(torch.nn.Module):
    """Interval loss for networks predicting a lower and an upper bound.

    Per sample the loss is
        (uhat - lhat) + 2/alpha * ((lhat - y) * 1{lhat - y >= 0}
                                   + (y - uhat) * 1{y - uhat >= 0})
    and the mean over the batch is returned.
    """

    def __init__(self, alpha):
        super().__init__()
        # Weight of the miss penalty is 2/alpha. NOTE(review): this matches the
        # form of the interval score, where alpha is usually the miscoverage
        # level (e.g. 0.1 for a 90% interval); an earlier docstring called it
        # the confidence level, so confirm which convention callers use.
        self.alpha = alpha

    def forward(self, ic, y):
        """Compute the interval loss.

        Parameters:
        ic: network output; column 0 is the lower bound lhat and column 1
            the upper bound uhat.
        y: targets, one column per sample (compared against (n, 1) bounds).

        Returns the batch mean of the interval loss.
        """
        lower = ic[:, 0:1]
        upper = ic[:, 1:2]

        below = lower - y  # >= 0 when the target falls under the lower bound
        above = y - upper  # >= 0 when the target exceeds the upper bound

        below_hit = below.ge(0).float().detach()
        above_hit = above.ge(0).float().detach()
        penalty = below * below_hit + above * above_hit

        width = upper - lower
        return (width + 2 / self.alpha * penalty).mean()

def _qnet_batch_loss(net, loss, qtype, x, y, device):
    """Move one ``(x, y)`` batch to ``device`` and return ``net``'s loss on it.

    For ``qtype == 'SQR'`` every sample gets its own quantile level drawn with
    ``torch.rand`` (uniform on [0, 1)), which is appended to the input through
    ``augment`` and passed to the pinball loss. For any other ``qtype`` the raw
    network output is passed to ``loss`` together with ``y``.
    """
    x = x.to(device)
    y = y.to(device)
    if qtype == 'SQR':
        taus = torch.rand(x.size(0), 1).to(device)
        return loss(net(augment(x, taus, device)), y, taus)
    return loss(net(x), y)


def trainQuantileNet(Dset, valDset, device, qnet_conf, get_best=False, bound_last=False):
    """Train a quantile ('SQR') or interval ('Interval') network with Adam.

    Parameters
    ----------
    Dset : training dataset yielding ``(x, y)`` pairs. It must expose ``.dset``
        and ``.out``, whose ``shape[1]`` sets the network input/output sizes.
    valDset : validation dataset yielding ``(x, y)`` pairs.
    device : torch device used for the network and the batches.
    qnet_conf : dict with keys 'nepochs', 'qtype' ('SQR' or 'Interval'),
        'n_hlayers', 'n_inner_neurons', 'l_rate', 'wd', 'batch_size' and, for
        'Interval', 'alpha'.
    get_best : if True, reload the parameters of the epoch with the lowest
        average validation loss before returning.
    bound_last : forwarded to ``DeepNetBN`` for the 'SQR' network only.

    Returns
    -------
    (net, train_loss_out, test_loss_out, best_ind) : the trained network, the
    per-epoch average training and validation losses (np.ndarray of length
    'nepochs'), and the index of the selected epoch (0 when ``get_best`` is
    False).

    Raises
    ------
    Exception : if 'qtype' is not 'SQR' or 'Interval'.

    Notes
    -----
    For 'SQR' the validation loss uses freshly sampled random quantile levels
    every epoch, so it is a noisy estimate.
    """
    epochs = qnet_conf['nepochs']
    qtype = qnet_conf['qtype']

    if qtype == 'SQR':
        # One extra input column carries the quantile level added by augment().
        net = DeepNetBN(n_in=Dset.dset.shape[1] + 1,
                        n_out=Dset.out.shape[1],
                        n_hlayers=qnet_conf['n_hlayers'],
                        n_inner_neurons=qnet_conf['n_inner_neurons'],
                        bound_last=bound_last)
        loss = QuantileLoss()
    elif qtype == 'Interval':
        # Two outputs: lower bound (column 0) and upper bound (column 1).
        net = DeepNetBN(n_in=Dset.dset.shape[1],
                        n_out=2,
                        n_hlayers=qnet_conf['n_hlayers'],
                        n_inner_neurons=qnet_conf['n_inner_neurons'])
        loss = IntervalLoss(alpha=qnet_conf['alpha'])
    else:
        raise Exception('Unknown type of quantile network')
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=qnet_conf['l_rate'],
                                 weight_decay=qnet_conf['wd'])

    loader = torch.utils.data.DataLoader(Dset, shuffle=True,
                                         batch_size=qnet_conf['batch_size'],
                                         pin_memory=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(valDset,
                                              batch_size=qnet_conf['batch_size'],
                                              shuffle=False,
                                              pin_memory=True, num_workers=0)

    pbar = tqdm(range(epochs))

    train_loss_out = np.zeros(epochs)
    test_loss_out = np.zeros(epochs)
    best_ind, best_loss, best_pars = 0, np.inf, None

    for cnt in pbar:
        net.train()
        train_loss, n_train_batches = 0.0, 0
        for xi, yi in loader:
            # Size-1 batches are skipped (presumably because the batch
            # normalisation in DeepNetBN cannot train on a single sample).
            if yi.shape[0] <= 1:
                continue
            optimizer.zero_grad()
            lossval = _qnet_batch_loss(net, loss, qtype, xi, yi, device)
            train_loss += lossval.item()
            n_train_batches += 1
            lossval.backward()
            optimizer.step()

        test_loss = 0.0
        with torch.no_grad():
            net.eval()
            for X, y in test_loader:
                test_loss += _qnet_batch_loss(net, loss, qtype, X, y, device).item()

        # Average over the batches actually evaluated; the validation loss is
        # normalised by the validation loader (not the training loader), and
        # max(..., 1) guards against empty loaders.
        avg_train = train_loss / max(n_train_batches, 1)
        avg_test = test_loss / max(len(test_loader), 1)

        pbar.set_description('Training {0}: '.format(qtype) +
                             'Avg Train Loss: {0:.4f}, '.format(avg_train) +
                             'Avg Val Loss: {0:.4f}'.format(avg_test))
        train_loss_out[cnt] = avg_train
        test_loss_out[cnt] = avg_test

        # Snapshot the parameters of the epoch with the lowest validation loss
        # (the first epoch is always stored).
        if get_best and (best_pars is None or avg_test < best_loss):
            best_ind, best_loss = cnt, avg_test
            best_pars = copy.deepcopy(net.state_dict())

    if best_pars is not None:
        net.load_state_dict(best_pars)
    return net, train_loss_out, test_loss_out, best_ind


def find_cdf(Xdata, Qnet, device, quants):
    ''' Evaluate the quantile network ``Qnet`` at the levels ``quants`` for
    every row of ``Xdata``.

    Parameters
    ----------
    Xdata : torch.Tensor whose rows (``Xdata.shape[0]``) are the inputs.
    Qnet : quantile network taking the ``augment``-ed input. Its output is
        flattened per level, so presumably it returns one value per row.
    device : torch device where the network runs.
    quants : quantile levels in (0, 1); a numpy array, list or scalar.

    Returns
    -------
    np.ndarray of shape (len(quants), Xdata.shape[0]). Row k holds the
    predicted quants[k]-quantile for every input, so (column i, quants) is a
    discretised inverse CDF of Y | X = Xdata[i]. The network is left in eval
    mode.
    '''
    quants = np.atleast_1d(np.asarray(quants, dtype=float)).ravel()
    Frange_test = np.zeros((quants.size, Xdata.shape[0]))  # one row per level
    with torch.no_grad():
        Qnet.eval()
        # Converting/moving the inputs is identical for every level: do it once.
        X_dev = Xdata.float().to(device)
        for cnt, q in enumerate(quants):
            Frange_test[cnt] = Qnet(augment(X_dev, float(q), device)).cpu().numpy().ravel()

    return Frange_test

def find_IC(Xdata, Qnet, device, conf):
    """Return the bounds of the central ``conf`` interval predicted by ``Qnet``.

    The network is evaluated (through ``find_cdf``) at the two levels
    ``(1 - conf) / 2`` and ``0.5 + conf / 2``. The result has shape
    (2, Xdata.shape[0]): row 0 is the lower bound and row 1 the upper bound
    for each input.
    """
    lower_level = (1 - conf) / 2
    upper_level = 0.5 + conf / 2
    levels = np.array([lower_level, upper_level])
    return find_cdf(Xdata, Qnet, device, levels)
    
    
def _cdf_from_quantiles(Frange, probs, values):
    """Approximate F(values) column-wise from a table of predicted quantiles.

    ``Frange[k, i]`` is the predicted ``probs[k]``-quantile for sample ``i``.
    For every sample, the level whose predicted quantile is closest to the
    value is returned. Values above the largest predicted quantile map to 1
    and values below the smallest map to 0. ``values`` is a scalar (or a
    1-element array) or holds one value per sample.
    """
    values = np.asarray(values, dtype=float)
    # A 1-element array acts as a scalar; otherwise use one value per column.
    values = values.reshape(()) if values.size == 1 else values.reshape(-1)
    ind = np.argmin(np.abs(Frange - values), axis=0)
    P = np.asarray(probs, dtype=float)[ind]
    P[values > np.amax(Frange, axis=0)] = 1
    P[values < np.amin(Frange, axis=0)] = 0
    return P


def _band_probability(Frange, probs, center, delta):
    """Estimate P(center - delta < Y < center + delta) from a quantile table.

    The result is F(center + delta) - F(center - delta), cropped to [0, 1]
    because crossing quantiles can push the difference outside that range.
    """
    Psup = _cdf_from_quantiles(Frange, probs, center + delta)
    Pinf = _cdf_from_quantiles(Frange, probs, center - delta)
    return np.clip(Psup - Pinf, 0, 1)


def find_Pg(Xdata, dcenter, Qnet, eps, device, probs, notion='abs', sim_type='outY', const=0):
    """For a given set of inputs, compute the probability of being epsilon
    good around a certain value.

    Computes:
    --------------
    + If notion=='abs' and sim_type=='outY' the algorithm estimates:
    P( |Y-f(x)| < eps | X=x) = F_{Y|X=x} (dcenter+eps) - F_{Y|X=x} (dcenter-eps)
    where dcenter = f(x) is the output of the regressor and eps is a scalar.

    + If notion=='rel' and sim_type=='outY' the algorithm estimates
    P( |Y-f(x)| < eps*|f(x)+const| | X=x) =
        F_{Y|X=x}(dcenter + Delta) - F_{Y|X=x}(dcenter - Delta),
    with Delta = eps*|dcenter + const|. With const=0 this is
    P( |Y-f(x)|/|f(x)| < eps | X=x).

    + If notion=='abs' and sim_type=='absE', the quantile network is over E=|Y-f(x)|
    P( |Y-f(x)| < eps | X=x) = F_{E|X=x}(eps)

    + If notion=='rel' and sim_type=='absE':
    P( |Y-f(x)| < eps |f(x)| |X=x) = F_{E|X=x} (eps)

    in this case, the value of dcenter has to be set as dcenter=0 and the value passed to eps must be eps*abs(f(x))

    The CDF value at a point is approximated by the level of the nearest
    predicted quantile. Points above the largest predicted quantile get
    probability 1, and points below the smallest get probability 0.

    Parameters
    ----------
    - Xdata: (torch.tensor) the x values for which we compute the probability of being good for |X=x
    - dcenter (np.ndarray): center of shape (n_samples). Combined with eps it is used to compute the probability
    of being epsilon good.
    - Qnet (torch.nn.Module): quantile regression network used to estimate the corresponding distribution function.
    - eps (np.ndarray or float): used to compute the probability of being eps-good, according to the case.
    - device: 'cuda' or 'cpu'
    - probs: (np.ndarray) vector of increasing values in (0,1) where to compute the CDF to estimate the probability of
    being epsilon good.
    - notion: (str), if 'abs' the absolute epsilon good probability is computed; if 'rel' the relative notion is used.
    - sim_type (str): if 'outY' the samples and quantile function are obtained from the output Y|X=x. If 'absE'
    the network is assumed to model the absolute error |Y-f(x)||X=x.
    - const: (float) offset added to dcenter when computing the relative band (outY/rel only).

    Returns
    -------
    Pg: (np.ndarray) vector of eps-good probabilities, one for each row in Xdata
    Frange_test: (np.ndarray) matrix with the ranges of y corresponding to the quantiles in probs. Each column corresponds to
    one range, so (Frange_test[:,i], probs) contains the CDF of the Qnet given Xdata[i,:].

    Raises
    ------
    Exception: for an unknown sim_type or notion.
    """
    if sim_type not in ('outY', 'absE'):
        raise Exception('Unknown type of simulation, options are "outY" or "absE"')
    if notion not in ('abs', 'rel'):
        raise Exception('Unknown notion of Eps good, must be "abs" or "rel"')

    # First find the CDF for each input
    Frange_test = find_cdf(Xdata, Qnet, device, probs)

    if sim_type == 'outY':  # extract Pg from the distribution of Y|X=x
        # Half-width of the eps-good band around dcenter.
        Delta = eps if notion == 'abs' else eps * abs(dcenter + const)
        Pg = _band_probability(Frange_test, probs, dcenter, Delta)
    else:  # 'absE': Pg = F_{E|X=x}(eps)
        # In the relative notion, the value passed in eps is already eps*|f(x)|.
        Pg = _cdf_from_quantiles(Frange_test, probs, eps)

    return Pg, Frange_test


def find_Pg2(Qnet, Xdata, fx, eps, notion, sim_type, device, Pg_conf):
    """Batched variant of ``find_Pg`` configured through ``Pg_conf``.

    All (sample, level) pairs are evaluated in a single forward pass of Qnet.

    Parameters
    ----------
    - Qnet (torch.nn.Module): quantile regression network.
    - Xdata (torch.tensor): inputs, one row per sample.
    - fx (np.ndarray): regressor output f(x), one value per sample.
    - eps: eps-good tolerance. The absE/'abs' and absErel paths pass
      ``np.array([eps], ndmin=2)`` to the scaler, so eps is presumably a scalar there.
    - notion (str): 'abs' or 'rel'.
    - sim_type (str): 'outY' (network over Y|X=x), 'absE' (network over the
      scaled absolute error), or 'absErel' (network evaluated directly at the scaled eps,
      'rel' notion only; presumably a relative error model, so confirm with the caller).
    - device: torch device where the network runs.
    - Pg_conf (dict): 'probs' (levels, required); 'musig' (offset added to fx
      for the relative notion with outY/absE); 'scaler_err' (object with
      ``.transform`` mapping errors to the network's scale, needed by absE/absErel).

    Returns
    -------
    Pg (np.ndarray): eps-good probability per sample.
    Frange_test (np.ndarray): predicted quantiles, shape (len(probs), n_samples).

    Raises
    ------
    Exception: for unsupported notion/sim_type combinations or missing Pg_conf keys.
    """
    probs = np.asarray(Pg_conf['probs'])
    n_lev = probs.size
    with torch.no_grad():  # compute the CDF for each value of X
        Qnet.eval()
        # Levels are tiled [p0..pK-1, p0..pK-1, ...] and samples repeated
        # [x0 (K times), x1 (K times), ...], so stacked row i*K + k pairs
        # sample i with level k.
        probs_stack = torch.tensor(probs, dtype=torch.float32, device=device).repeat(Xdata.shape[0]).unsqueeze(1)
        X_stack = Xdata.repeat_interleave(repeats=n_lev, dim=0).float().to(device)
        out = Qnet(augment(X_stack, probs_stack, device))
    Frange_test = out.cpu().numpy().reshape((-1, n_lev)).T  # (levels, samples)

    if sim_type == 'outY':  # extract Pg from the distribution of Y|X=x
        if notion == 'abs':
            Delta = eps
        elif notion == 'rel':
            if 'musig' not in Pg_conf:
                raise Exception('For outY/rel eps combination, musig has to be passed in Pg_Conf')
            Delta = eps * abs(fx + Pg_conf['musig'])
        else:
            raise Exception('Unknown notion of Eps good, must be "abs" or "rel"')
        Pg = _band_probability(Frange_test, probs, fx, Delta)

    elif sim_type == 'absE':
        if 'scaler_err' not in Pg_conf:
            raise Exception('The absE notion requires the scaler of the error dataset in Pg_conf')
        scaler = Pg_conf['scaler_err']
        if notion == 'abs':
            eps_sc = scaler.transform(np.array([eps], ndmin=2))[0][0]  # scaled epsilon
        elif notion == 'rel':
            if 'musig' not in Pg_conf:
                raise Exception('The relative notion requires musig to be passed in Pg_conf')
            # Per-sample tolerance eps*|f(x)+musig| mapped to the error scale.
            eps_sc = scaler.transform(np.abs(fx.reshape(-1, 1) + Pg_conf['musig']) * eps).flatten()
        else:
            raise Exception('Unknown notion of Eps good, must be "abs" or "rel"')
        Pg = _cdf_from_quantiles(Frange_test, probs, eps_sc)  # F_{E|X=x}(eps)

    elif sim_type == 'absErel':
        if 'scaler_err' not in Pg_conf:
            raise Exception('The absErel notion requires the scaler of the error dataset in Pg_conf')
        if notion == 'abs':
            raise Exception('The absErel statistic cannot compute the abs error')
        if notion != 'rel':
            raise Exception('Unknown notion of Eps good, must be "abs" or "rel"')
        eps_sc = Pg_conf['scaler_err'].transform(np.array([eps], ndmin=2))[0][0]
        Pg = _cdf_from_quantiles(Frange_test, probs, eps_sc)  # F_{E|X=x}(eps)

    else:
        raise Exception('Unknown type of simulation, options are "outY", "absE" or "absErel"')

    return Pg, Frange_test