# -*- coding: utf-8 -*-


import torch
from torch.utils import data
from torch import nn
import copy
from getdata import Dset
from tqdm import tqdm
import numpy as np

def get_sim_device(in_string, verb=True):
    '''Select the device on which the simulation will run.

    Parameters
    ----------
    in_string : the requested device. The special value 'auto' picks CUDA
        when torch reports it as available and the CPU otherwise; any other
        value is passed through untouched.
    verb : when true, the chosen device is printed.

    Returns
    -------
    A torch.device when in_string is 'auto'; otherwise in_string itself,
    exactly as it was given.
    '''
    if in_string != 'auto':
        # Explicit request: hand it back unchanged.
        device = in_string
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if verb:
        print('-----\nUsing device {0}\n-----'.format(device))

    return device


class RegressionNet(nn.Module):
    '''Fully connected regression network with dropout.

    The stack is: one input block (Linear -> ReLU -> Dropout), then
    n_hlayers hidden blocks of the same shape and width n_inner_neurons,
    and finally a plain Linear layer producing n_out values.
    '''

    def __init__(self, n_in, n_out, n_hlayers, n_inner_neurons, dropout):
        super().__init__()

        def make_block(width_in):
            # One Linear -> ReLU -> Dropout unit ending at the hidden width.
            return [nn.Linear(width_in, n_inner_neurons),
                    nn.ReLU(),
                    nn.Dropout(dropout)]

        stack = make_block(n_in)
        for _ in range(n_hlayers):
            stack.extend(make_block(n_inner_neurons))
        stack.append(nn.Linear(n_inner_neurons, n_out))

        # Kept as a ModuleList named `layers` so parameter names in the
        # state dict stay of the form 'layers.<index>.<param>'.
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        '''Apply every layer in order to x and return the final output.'''
        out = x
        for layer in self.layers:
            out = layer(out)
        return out


class DeepNetBN(nn.Module):
    '''Fully connected network with batch normalisation.

    The stack is: one input block (Linear -> BatchNorm1d -> ReLU), then
    n_hlayers hidden blocks of the same shape and width n_inner_neurons,
    and a final Linear layer producing n_out values. When bound_last is
    true a Sigmoid is appended after the final Linear layer.
    '''

    def __init__(self, n_in, n_out, n_hlayers, n_inner_neurons, bound_last=False):
        super().__init__()

        width = n_inner_neurons
        stack = [nn.Linear(n_in, width), nn.BatchNorm1d(width), nn.ReLU()]
        for _ in range(n_hlayers):
            stack.extend((nn.Linear(width, width),
                          nn.BatchNorm1d(width),
                          nn.ReLU()))

        stack.append(nn.Linear(width, n_out))
        if bound_last:
            stack.append(nn.Sigmoid())

        # Kept as a ModuleList named `layers` so parameter names in the
        # state dict stay of the form 'layers.<index>.<param>'.
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        '''Apply every layer in order to x and return the final output.'''
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

    def fwd_sub(self, x):
        '''Run the network and also return the second-to-last activation.

        Returns the tuple (feats, out) where out equals forward(x) and feats
        is the output of the layer at index len(self.layers) - 2. Without
        bound_last that is the last ReLU output; with bound_last it is the
        output of the final Linear layer, i.e. the value fed to the Sigmoid.
        '''
        tap = len(self.layers) - 2
        out = x
        for idx, layer in enumerate(self.layers):
            out = layer(out)
            if idx == tap:
                feats = out
        return feats, out


class DVIC_net(nn.Module):
    '''Batch-normalised network whose output is symmetric in the last two inputs.

    The stack is: one input block (Linear -> BatchNorm1d -> ReLU), then
    n_hlayers hidden blocks of the same shape and width n_inner_neurons,
    a final Linear layer producing n_out values, and a Sigmoid.

    forward() evaluates that stack on the input and on a copy of the input
    with its last two columns swapped, and returns the mean of the two
    results, so exchanging the last two columns does not change the output.
    '''

    def __init__(self, n_in, n_out, n_hlayers, n_inner_neurons):
        super(DVIC_net, self).__init__()

        layer_list = [nn.Linear(n_in, n_inner_neurons),
                      nn.BatchNorm1d(n_inner_neurons),
                      nn.ReLU()]

        for _ in range(n_hlayers):
            layer_list += [nn.Linear(n_inner_neurons, n_inner_neurons),
                           nn.BatchNorm1d(n_inner_neurons),
                           nn.ReLU()]

        layer_list += [nn.Linear(n_inner_neurons, n_out), nn.Sigmoid()]

        self.layers = nn.ModuleList(layer_list)

    def forward(self, x, device=None):
        '''Return the average of the network applied to x and to x with its
        last two columns swapped.

        Parameters
        ----------
        x : tensor indexed as (sample, feature) with at least two features.
        device : ignored. Kept (now optional) so existing calls of the form
            model(x, device) keep working; the index tensor is always
            created on x.device, which is the only placement index_select
            accepts. Being optional also lets the model be called as
            model(x).

        Raises
        ------
        ValueError if x has fewer than two columns.
        '''
        n_feat = x.shape[1]
        if n_feat < 2:
            raise ValueError(
                'DVIC_net needs at least 2 input columns, got {0}'.format(n_feat))

        # Identity permutation except for the last two columns, which are
        # exchanged. For n_feat == 2 this is simply [1, 0].
        order = list(range(n_feat - 2)) + [n_feat - 1, n_feat - 2]
        swap_idx = torch.tensor(order, dtype=torch.long, device=x.device)
        x2 = torch.index_select(x, 1, swap_idx)

        # Both branches go through every layer one after the other (original
        # first, swapped second), so in training mode each BatchNorm layer
        # sees the two batches separately and in this order.
        out, out2 = x, x2
        for layer in self.layers:
            out = layer(out)
            out2 = layer(out2)

        return 0.5 * (out + out2)
   
    
    
    
def train_model(train_loader, model, loss_fn, optimizer, device):
    '''Run one optimisation pass over train_loader.

    The model is put in training mode. For every (X, y) batch whose first
    dimension is larger than one, the gradients are reset, the batch is
    moved to device, the loss is back-propagated and the optimizer takes
    one step. Batches holding a single sample are skipped entirely
    (presumably because batch normalisation cannot be trained on one
    sample -- confirm with the models actually used).

    Returns the sum of the per-batch loss values of the processed batches.
    '''
    model.train()

    running_loss = 0
    for inputs, targets in train_loader:
        if targets.shape[0] <= 1:
            continue

        optimizer.zero_grad()

        inputs = inputs.to(device)
        targets = targets.to(device)

        batch_loss = loss_fn(model(inputs), targets)
        running_loss += batch_loss.item()

        batch_loss.backward()
        optimizer.step()

    return running_loss


def test_model(test_loader, model, loss_fn, device):
    '''Evaluate the model on every batch of test_loader.

    The model is switched to evaluation mode (and left in that mode) and
    gradient tracking is disabled during the pass. Each (X, y) batch is
    moved to device before the loss is computed.

    Returns the sum of the per-batch loss values.
    '''
    with torch.no_grad():
        model.eval()
        return sum(
            loss_fn(model(inputs.to(device)), targets.to(device)).item()
            for inputs, targets in test_loader
        )

def train_regressor(trainDS, testDS, device, netconf, get_best=False):
    '''Trains a DeepNetBN regressor with the MSE loss and the Adam optimizer.

    Parameters
    ----------
    trainDS, testDS : datasets accepted by torch DataLoader. trainDS must
        expose `dset` and `out` attributes whose `shape[1]` give the number
        of network inputs and outputs respectively.
    device : device on which the network and the loss histories are placed.
    netconf : dict with the keys 'l_rate', 'nepochs', 'batch_size',
        'n_hlayers', 'n_inner_neurons' and 'wd' (Adam weight decay).
    get_best : see below.

    Returns
    -------
    + If get_best=False: (regressor, train_loss, test_loss), where the
      regressor is in its state after the last epoch and the losses are
      tensors with one entry per epoch.
    + If get_best=True: (regressor, train_loss, test_loss, best_ind), where
      the regressor holds the parameters that minimised the test loss and
      best_ind is the epoch where that happened, counting 0 as the first
      epoch. If nepochs is 0 the untrained regressor is returned and
      best_ind is None.

    The per-epoch losses are the sums returned by train_model / test_model
    divided by the number of batches in the corresponding loader.
    '''

    learning_rate = netconf['l_rate']
    nepochs = netconf['nepochs']

    train_loader = data.DataLoader(trainDS, batch_size=netconf['batch_size'],
                                   shuffle=True,
                                   num_workers=0,
                                   pin_memory=True)

    test_loader = data.DataLoader(testDS, batch_size=netconf['batch_size'],
                                  shuffle=False,
                                  pin_memory=True, num_workers=0)

    # Configure the regressor and optimizer
    reg_network = DeepNetBN(n_in=trainDS.dset.shape[1],
                            n_out=trainDS.out.shape[1],
                            n_hlayers=netconf['n_hlayers'],
                            n_inner_neurons=netconf['n_inner_neurons'])

    loss_fn = nn.MSELoss()

    reg_network.to(device)
    optimizer = torch.optim.Adam(params=reg_network.parameters(),
                                 lr=learning_rate, weight_decay=netconf['wd'])

    train_loss = torch.zeros(size=(nepochs,), device=device)
    test_loss = torch.zeros(size=(nepochs,), device=device)

    # Best-epoch bookkeeping (only updated when get_best is true).
    best_ind = None
    best_loss = None
    best_pars = None

    pbar = tqdm(range(nepochs))

    # Train and test loop for the regressor
    for cnt in pbar:
        train_loss[cnt] = train_model(train_loader, reg_network, loss_fn,
                                      optimizer, device)
        # Average loss per batch. NOTE: train_model skips single-sample
        # batches, but they are still counted in len(train_loader).
        train_loss[cnt] /= len(train_loader)

        test_loss[cnt] = test_model(test_loader, reg_network, loss_fn, device)
        test_loss[cnt] /= len(test_loader)  # average loss per batch

        pbar.set_description('Regressor: Avg Train Loss: {0:.4f}'.format(train_loss[cnt])+
                             ' Avg Test Loss: {0:.4f}'.format(test_loss[cnt]))

        # The first epoch is always taken as the initial best; later epochs
        # replace it only on a strictly lower test loss.
        if get_best and (best_pars is None or best_loss > test_loss[cnt]):
            best_ind = cnt
            best_loss = test_loss[cnt]
            # state_dict() returns references to the live parameter
            # tensors, which the optimizer keeps updating in place. A deep
            # copy is required on every snapshot (including the first one)
            # or the stored "best" weights silently follow later epochs.
            best_pars = copy.deepcopy(reg_network.state_dict())

    if get_best:
        if best_pars is not None:
            reg_network.load_state_dict(best_pars)
        return reg_network, train_loss, test_loss, best_ind

    return reg_network, train_loss, test_loss


def load_reg_model(n_in, n_out, reg_res, reg_par, device):
    """Builds a DeepNetBN regressor and loads stored parameters into it.

    Parameters
    ----------
    n_in, n_out : number of network inputs and outputs.
    reg_res : mapping whose reg_res['conf'] entry provides 'n_hlayers' and
        'n_inner_neurons' for the network being rebuilt.
    reg_par : whatever torch.load accepts as its first argument; it must
        deserialize to a state dict matching the rebuilt network.
    device : target device; used both as the map_location of torch.load
        and as the final placement of the network.

    Returns
    -------
    The network with the stored parameters loaded, moved to device.

    The parameters are always loaded. Mapping the checkpoint straight to
    `device` avoids returning a randomly initialised network when `device`
    is not the literal string 'cpu' and CUDA is unavailable (for example
    when `device` is a torch.device object).
    """
    reg_network = DeepNetBN(n_in=n_in,
                            n_out=n_out,
                            n_hlayers=reg_res['conf']['n_hlayers'],
                            n_inner_neurons=reg_res['conf']['n_inner_neurons'])

    # NOTE(security): torch.load unpickles its input; only use it on
    # checkpoint files from a trusted source.
    state = torch.load(reg_par, map_location=device)
    reg_network.load_state_dict(state)

    reg_network.to(device)
    return reg_network