from icecream import ic
import numpy as np
from torch.utils.data.dataset import random_split
from torch.utils.data.dataloader import DataLoader
import torch
from models.Nets import CNNMnist, MLP
from torch import nn
from numpy.lib.function_base import append
import os
from os import WIFCONTINUED
import copy
import csv


class CentralisedLearning:
    """Train a single MLP on the whole training set and log its statistics.

    The training set is split into a training part and a validation part
    (``args.val_split`` is the validation fraction).  Training uses SGD and
    stops early when the validation loss increases or reaches zero.  Per-epoch
    statistics are stored under the single node id ``0`` and written to disk
    by :meth:`write_stats`.

    Attributes of ``args`` read by this class: ``val_split``, ``bs``,
    ``num_classes``, ``device``, ``lr``, ``momentum``, ``epochs``,
    ``verbose``, ``outfolder`` and ``seed``.
    """

    def __init__(self, args, dataset_train, dataset_test) -> None:
        """Split the data, build the loaders, the loss and the network.

        Args:
            args: configuration namespace (see the class docstring).
            dataset_train: indexable dataset of ``(sample, label)`` pairs; it
                is split into training and validation subsets.
            dataset_test: dataset used only for test evaluation.
        """
        self._args = args
        self._dataset_train = dataset_train

        # len() is the size random_split() validates the split lengths against.
        n_images = len(dataset_train)
        train_size = int(np.floor(n_images * (1 - self._args.val_split)))
        valid_size = n_images - train_size
        local_train, local_valid = random_split(
            dataset_train, [train_size, valid_size])

        # Only the training loader is shuffled.
        self.ldr_train = DataLoader(
            local_train, batch_size=self._args.bs, shuffle=True)
        self.ldr_valid = DataLoader(local_valid, batch_size=self._args.bs)
        self.ldr_test = DataLoader(dataset_test, batch_size=self._args.bs)

        self._loss_func = nn.CrossEntropyLoss()

        # The MLP input size is the number of elements in one sample.
        len_in = int(np.prod(dataset_train[0][0].shape))
        self._net = MLP(dim_in=len_in, dim_out=self._args.num_classes).to(
            self._args.device)

        # Statistics are keyed by node id; the centralised run uses id 0 only.
        self._train_loss = {0: []}
        self._validation_loss = {0: []}
        self._test_loss = {0: []}
        self._test_accuracy = {0: []}

    def run(self):
        """Train the network, then write the collected statistics to disk."""
        self.centralised_train()
        self.write_stats()

    def centralised_train(self):
        """Train ``self._net`` with SGD and early stopping on validation loss.

        After every epoch the validation loss is computed.  Training stops
        when that loss is larger than the one of the previous epoch (the
        network is then rolled back to the previous epoch's weights) or when
        it is exactly zero (the current network is kept).  The validation
        loss of the stopping epoch is still recorded; test statistics are
        recorded only for epochs that did not trigger the stop.

        Results are stored in ``self._train_loss``, ``self._validation_loss``,
        ``self._test_loss`` and ``self._test_accuracy`` under key ``0``.
        """
        self._net.train()

        optimizer = torch.optim.SGD(
            self._net.parameters(), lr=self._args.lr,
            momentum=self._args.momentum)

        epoch_loss = []
        valid_loss = []
        test_loss = []
        test_accuracy = []
        current_valid_loss = np.inf
        # Snapshot of the network after the last epoch that did not stop.
        prev_net = None

        for epoch in range(self._args.epochs):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images = images.to(self._args.device)
                labels = labels.to(self._args.device)
                self._net.zero_grad()
                log_probs = self._net(images)
                loss = self._loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self._args.verbose and batch_idx % 10 == 0:
                    # validation() returns (loss, accuracy); only the loss is
                    # reported in this progress line.
                    running_valid_loss, _ = self.validation(
                        self._net, self.ldr_valid)
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tValid: {:.6f}'.format(
                        epoch, batch_idx * len(images),
                        len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train),
                        loss.item(), running_valid_loss))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))

            new_valid_loss, new_acc = self.validation(
                self._net, self.ldr_valid)
            if self._args.verbose:
                print('Validation Loss: {:.6f}'.format(new_valid_loss))
            print(
                'Validation Loss: {:.6f}\tAccuracy: {:.6f}'.format(
                    new_valid_loss, new_acc))

            loss_increased = new_valid_loss > current_valid_loss
            if loss_increased or new_valid_loss == 0:
                if self._args.verbose:
                    if loss_increased:
                        print('Early stop at local epoch {}: for increasing valid loss [new:{:.6f} > prev:{:.6f}]'.format(
                            epoch, new_valid_loss, current_valid_loss))
                    else:
                        print('Early stop at local epoch {}: validation loss reached zero'.format(
                            epoch))
                # Roll back only when the loss got worse.  A zero loss is the
                # best possible outcome, so that network is kept; on the first
                # epoch there is no earlier snapshot to roll back to anyway.
                if loss_increased and prev_net is not None:
                    self._net = prev_net
                # The last recorded validation loss is the one that triggered
                # the stop.
                current_valid_loss = new_valid_loss
                valid_loss.append(current_valid_loss)
                break

            # Back up the current network before the next epoch.
            prev_net = copy.deepcopy(self._net)
            current_valid_loss = new_valid_loss
            valid_loss.append(current_valid_loss)
            tloss, tacc = self.validation(self._net, self.ldr_test)
            test_loss.append(tloss)
            test_accuracy.append(tacc)

        self._train_loss[0] = epoch_loss
        self._validation_loss[0] = valid_loss
        self._test_loss[0] = test_loss
        self._test_accuracy[0] = test_accuracy

    def validation(self, net, dataset):
        """Evaluate ``net`` on every batch produced by ``dataset``.

        The network is put in evaluation mode for the duration of the call
        and its previous training/evaluation mode is restored afterwards.

        Args:
            net: the network to evaluate.
            dataset: iterable yielding ``(images, labels)`` batches.

        Returns:
            ``(loss, accuracy)``: the unweighted means over batches of the
            per-batch loss and per-batch accuracy.  ``(nan, nan)`` when
            ``dataset`` yields no batch; a NaN loss never satisfies the
            early-stopping conditions in :meth:`centralised_train`.
        """
        total_loss = 0.0
        total_accuracy = 0.0
        n_batches = 0

        was_training = net.training
        net.eval()
        try:
            with torch.no_grad():
                for images, labels in dataset:
                    images = images.to(self._args.device)
                    labels = labels.to(self._args.device)

                    log_probs = net(images)
                    # Softmax is monotonic, so the argmax of the raw outputs
                    # is the predicted class.
                    y_pred = log_probs.argmax(dim=1)
                    accuracy = (labels == y_pred).type(torch.float).mean()

                    loss = self._loss_func(log_probs, labels)
                    total_loss += loss.item()
                    total_accuracy += accuracy.item()
                    n_batches += 1
        finally:
            net.train(was_training)

        if n_batches == 0:
            return float('nan'), float('nan')
        return total_loss / n_batches, total_accuracy / n_batches

    def write_stats(self):
        """Write all recorded statistics to one file under ``stats/``.

        The path is ``stats/<args.outfolder>/loss_centr_<args.seed>.tsv``;
        missing directories are created.  Each row is
        ``nodeid, time, loss, loss_type`` where ``loss_type`` is one of
        ``train``, ``validation``, ``test`` or ``accuracy``.
        """
        filename = "stats/" + self._args.outfolder + \
            "/loss_centr_" + str(self._args.seed) + ".tsv"
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        series = (
            (self._train_loss, 'train'),
            (self._validation_loss, 'validation'),
            (self._test_loss, 'test'),
            (self._test_accuracy, 'accuracy'),
        )

        # newline='' lets the csv module control line endings (otherwise
        # blank rows appear on platforms that translate '\n').
        # NOTE: despite the .tsv extension the default csv dialect
        # (comma-separated) is used; kept unchanged to preserve the format.
        with open(filename, 'w', newline='') as f:
            wr = csv.writer(f)
            wr.writerow(['nodeid', 'time', 'loss', 'loss_type'])
            for stats, loss_type in series:
                for node_id, values in stats.items():
                    for t, value in enumerate(values):
                        wr.writerow([node_id, t, value, loss_type])
