#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import random
import copy
import numpy as np
import torch
from sklearn import metrics
from torch import autograd, nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split


class DatasetSplit(Dataset):
    """Index-remapped view over another dataset.

    Position ``k`` of this view corresponds to position ``idxs[k]`` of the
    wrapped ``dataset``; each underlying sample is unpacked into an
    ``(image, label)`` pair before being returned.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # copy into a list so the indices can be indexed and measured even
        # when the caller passes a set, range, or other iterable
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        sample = self.dataset[source_position]
        image, label = sample
        return image, label


class LocalUpdate(object):
    """Local (client-side) training for one federated participant.

    The client's samples (``idxs`` into ``dataset``) are randomly split into a
    local training set and a local validation set. ``train`` then runs up to
    ``args.local_ep`` epochs of SGD with validation-based early stopping.

    Fields read from ``args``: ``val_split``, ``local_bs``, ``lr``,
    ``local_ep``, ``device``, ``verbose``.
    """

    def __init__(self, args, dataset=None, idxs=None):
        """Build the local train/validation data loaders.

        Args:
            args: experiment configuration (see the class docstring for the
                fields used).
            dataset: indexable dataset whose items unpack to ``(image, label)``.
            idxs: indices into ``dataset`` owned by this client.
        """
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        # NOTE(review): not used inside this class; presumably kept for
        # external callers — confirm before removing.
        self.selected_clients = []

        # local dataset split of train and validation
        train_size = int(np.floor(len(idxs) * (1 - self.args.val_split)))
        valid_size = len(idxs) - train_size
        local_train, local_valid = random_split(
            DatasetSplit(dataset, idxs), [train_size, valid_size])

        # load train and validation set
        self.ldr_train = DataLoader(
            local_train, batch_size=self.args.local_bs, shuffle=True)
        self.ldr_valid = DataLoader(
            local_valid, batch_size=self.args.local_bs)

    def train(self, net):
        """Train ``net`` locally with early stopping on the validation loss.

        After every epoch the model is evaluated on the local validation set.
        If the validation loss is higher than after the previous epoch (or is
        exactly 0), training stops and the weights from before that epoch are
        restored. When the validation set is empty its loss is NaN, so early
        stopping never triggers and all ``args.local_ep`` epochs run.

        Args:
            net: ``torch.nn.Module`` to train. It is updated in place, so after
                an early stop the caller's object still holds the discarded
                weights; use the returned state dict.

        Returns:
            tuple: ``(state_dict, mean_train_loss, valid_loss)``.
            ``mean_train_loss`` is the mean batch loss per epoch, averaged over
            the epochs actually run. ``valid_loss`` is the last measured
            validation loss; on early stop this is the increased loss that
            triggered the stop, not the loss of the returned weights.
        """
        net.train()
        # save local copy of net for rollback when increasing validation loss
        prev_net = copy.deepcopy(net)
        optimizer = torch.optim.SGD(
            net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = 0.0
        epochs_run = 0
        # np.Inf was removed in NumPy 2.0; np.inf works on every version
        current_valid_loss = np.inf
        for epoch in range(self.args.local_ep):
            epochs_run += 1
            batch_loss = 0.0
            n_batches = 0
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images = images.to(self.args.device)
                labels = labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx * len(images),
                        len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss += loss.item()
                n_batches += 1
            # mean batch loss of this epoch (0 when the training split is empty)
            epoch_loss += batch_loss / max(n_batches, 1)

            new_valid_loss, new_acc = self.validation(net, self.ldr_valid)
            if self.args.verbose:
                print(
                    'Validation Loss: {:.6f}\tAccuracy: {:.6f}'.format(new_valid_loss, new_acc))

            # NOTE(review): a validation loss of exactly 0 also stops training;
            # the intent is not visible here, kept as in the original.
            if new_valid_loss > current_valid_loss or new_valid_loss == 0:
                if self.args.verbose:
                    print('Early stop at local epoch {}: for increasing valid loss [new:{:.6f} > prev:{:.6f}]'.format(
                        epoch, new_valid_loss, current_valid_loss))
                # report the last (increased) valid loss, but roll back to the
                # weights from before this epoch
                current_valid_loss = new_valid_loss
                net = prev_net
                break

            # backup current net before next epoch
            prev_net = copy.deepcopy(net)
            current_valid_loss = new_valid_loss

        # average over the epochs actually run, not the configured maximum
        epoch_loss /= max(epochs_run, 1)
        return net.state_dict(), epoch_loss, current_valid_loss

    def validation(self, net, dataset):
        """Evaluate ``net`` on ``dataset`` without tracking gradients.

        The model is put in eval mode for the pass, so dropout is disabled and
        batch-norm running statistics are not updated. Its previous train/eval
        mode is restored afterwards.

        Args:
            net: model producing per-class scores.
            dataset: iterable of ``(images, labels)`` batches (e.g. a
                ``DataLoader``).

        Returns:
            tuple: ``(loss, accuracy)``, both averaged per sample (not per
            batch). Both are NaN when ``dataset`` yields no samples.
        """
        was_training = net.training
        net.eval()
        total_loss = 0.0
        total_correct = 0
        total_seen = 0
        try:
            with torch.no_grad():
                for images, labels in dataset:
                    images = images.to(self.args.device)
                    labels = labels.to(self.args.device)
                    log_probs = net(images)
                    n = labels.size(0)
                    # loss_func (CrossEntropyLoss, default reduction='mean')
                    # averages over the batch; re-weight by batch size so a
                    # short final batch is not over-counted
                    total_loss += self.loss_func(log_probs, labels).item() * n
                    # argmax of the raw scores equals argmax of their softmax
                    total_correct += (log_probs.argmax(1) == labels).sum().item()
                    total_seen += n
        finally:
            net.train(was_training)

        if total_seen == 0:
            return float('nan'), float('nan')
        return total_loss / total_seen, total_correct / total_seen
