import argparse
from collections import deque
from abc import ABC, abstractmethod
from typing import List
import torch
from torch import nn

from torch.nn import Module
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from graph.sai_graph_generator import SAIGraph
from models.Update import LocalUpdate
from models.Fed import FedAvg, SocialFedAvg
import numpy as np
#from utils.messaging import Message
from message import Message
from models import Fed, Update
import copy

from queue import PriorityQueue


class AbstractPaiv(ABC):
    """Abstract interface for a PAIV agent holding a local model and local data.

    Concrete subclasses must implement every abstract property below
    (getter and, where declared, setter) plus :meth:`receive`,
    :meth:`train` and :meth:`reset_model`.
    """

    def __init__(self, id: int, args: argparse.Namespace, graph: SAIGraph, dataset: Dataset, data_idxs: List, model: Module):
        """Store the agent's identity, configuration, data and model.

        Args:
            id: identifier of this PAIV.
            args: parsed command-line / experiment arguments.
            graph: graph object this PAIV is part of.
            dataset: full dataset; only ``data_idxs`` belong to this PAIV
                (presumably — confirm with the subclass that consumes it).
            data_idxs: indices into ``dataset`` assigned to this PAIV.
            model: the local model.
        """
        # PAIV id
        self._id = id
        # input args
        self._args = args
        # self trust; -1 presumably means "not set yet" — assign before use
        self._selfconfidence = -1
        # graph object
        self._graph = graph
        # dataset and the list of idxs of the dataset owned by this PAIV
        self._dataset = dataset
        self._data_idxs = data_idxs
        # local model
        self._model = model
        # buffer to collect other models; concrete subclasses create it
        self._msg_buffer = None

    @property
    @abstractmethod
    def id(self):
        """Identifier of this PAIV."""

    @id.setter
    @abstractmethod
    def id(self, value):
        """Set the identifier of this PAIV."""

    @property
    @abstractmethod
    def args(self):
        """Experiment arguments."""
        return self._args

    @property
    @abstractmethod
    def selfconfidence(self):
        """Trust this PAIV places in its own model."""

    @selfconfidence.setter
    @abstractmethod
    def selfconfidence(self, value):
        """Set the trust this PAIV places in its own model."""

    @property
    @abstractmethod
    def graph(self):
        """Graph object this PAIV is part of."""

    @graph.setter
    @abstractmethod
    def graph(self, value):
        """Replace the graph object."""

    @property
    @abstractmethod
    def dataset(self):
        """Dataset this PAIV draws its samples from."""

    @dataset.setter
    @abstractmethod
    def dataset(self, value):
        """Replace the dataset."""

    @property
    @abstractmethod
    def data_idxs(self):
        """Indices of ``dataset`` assigned to this PAIV."""

    @data_idxs.setter
    @abstractmethod
    def data_idxs(self, value):
        """Replace the assigned indices."""

    @property
    @abstractmethod
    def model(self):
        """Local model."""

    @model.setter
    @abstractmethod
    def model(self, value):
        """Replace the local model."""

    @property
    @abstractmethod
    def msg_buffer(self):
        """Buffer collecting messages from other PAIVs."""

    @msg_buffer.setter
    @abstractmethod
    def msg_buffer(self, value):
        """Replace the message buffer."""

    @abstractmethod
    def receive(self, message):
        """Accept an incoming ``message`` from another PAIV."""

    @abstractmethod
    def train(self, *args, **kwargs):
        """Run one training round; the exact signature is implementation-defined."""

    @abstractmethod
    def reset_model(self):
        """Reset the local model."""


class simplePaiv(AbstractPaiv):
    """Concrete PAIV that merges peers' models and then trains locally.

    Incoming messages are stored in a :class:`queue.PriorityQueue`.
    ``train_role`` selects how ready models are merged before the local
    training step:

    * ``'standalone'``: trust-weighted aggregation via ``SocialFedAvg``;
    * ``'fed_client'``: the local model is replaced by the received one.
    """

    def __init__(self, id, args, graph, dataset, data_idxs, model, train_role):
        super().__init__(id=id, args=args, graph=graph, dataset=dataset,
                         data_idxs=data_idxs, model=model)
        # PriorityQueue needs orderable messages; presumably Message orders
        # by ``time`` — TODO confirm in message.Message
        self._msg_buffer = PriorityQueue()

        # training related properties
        self._loss_local = []

        self._loss_func = nn.CrossEntropyLoss()
        self._local_strategy = LocalUpdate(
            self._args, dataset=self._dataset, idxs=self._data_idxs)

        # ``np.Inf`` was removed in NumPy 2.0; ``np.inf`` works on every version
        self._best_valid_loss = np.inf

        self._train_role = train_role

    # ------------------------------------------------------------------
    # properties
    # ------------------------------------------------------------------
    @property
    def id(self):
        return self._id

    @id.setter
    def id(self, value):
        self._id = value

    @property
    def args(self):
        return self._args

    @property
    def selfconfidence(self):
        return self._selfconfidence

    @selfconfidence.setter
    def selfconfidence(self, value):
        self._selfconfidence = value

    @property
    def graph(self):
        return self._graph

    @graph.setter
    def graph(self, value):
        self._graph = value

    @property
    def dataset(self):
        return self._dataset

    @dataset.setter
    def dataset(self, value):
        self._dataset = value

    @property
    def data_idxs(self):
        return self._data_idxs

    @data_idxs.setter
    def data_idxs(self, value):
        self._data_idxs = value

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, value):
        self._model = value

    @property
    def msg_buffer(self):
        return self._msg_buffer

    @msg_buffer.setter
    def msg_buffer(self, value):
        self._msg_buffer = value

    @property
    def best_valid_loss(self):
        return self._best_valid_loss

    @best_valid_loss.setter
    def best_valid_loss(self, value):
        self._best_valid_loss = value

    @property
    def train_role(self):
        return self._train_role

    @train_role.setter
    def train_role(self, value):
        self._train_role = value

    # ------------------------------------------------------------------
    # messaging
    # ------------------------------------------------------------------
    def receive(self, message):
        """Enqueue ``message`` in the priority buffer.

        ``PriorityQueue`` has no ``append``; ``put`` is its insertion API.
        """
        self._msg_buffer.put(message)

    def clear_msg_buffer(self):
        """Discard every buffered message (``PriorityQueue`` has no ``clear``)."""
        while not self._msg_buffer.empty():
            self._msg_buffer.get_nowait()

    # ------------------------------------------------------------------
    # training / evaluation
    # ------------------------------------------------------------------
    def train(self, current_time):
        """Run one atomic training round.

        Workflow:
          1) remove every buffered message generated strictly before
             ``current_time`` and merge those models into the local model
             according to ``train_role`` (messages with ``time >=
             current_time`` stay buffered for a later round);
          2) train the local model on local data via ``LocalUpdate``.

        Returns:
            The training loss reported by ``LocalUpdate.train``.
        """
        if self.args.verbose:
            print("PAIV", self._id, "/ The msg_buffer size is:",
                  self._msg_buffer.qsize())
        if not self._msg_buffer.empty():
            models, active_neighs = self._collect_ready_models(current_time)
            if self.args.verbose:
                print("-- ", self._msg_buffer.qsize(),
                      "messages already available for the next training epoch")

            # Nothing ready yet: leave the local model untouched. This also
            # avoids popping from an empty dict in the fed_client branch.
            if models:
                if self._train_role == 'standalone':
                    # aggregate models using social aggregation
                    self._social_aggregation(
                        active_neighs=active_neighs, models=models)
                elif self._train_role == 'fed_client':
                    # simply substitute local with global model
                    self._substitute_local_with_global_model(models=models)

        # update local model starting from the (possibly) aggregated model
        w, loss, vloss = self._local_strategy.train(
            net=copy.deepcopy(self._model).to(self._args.device))
        self._model.load_state_dict(w)

        # NOTE: despite its name this stores the most recent validation loss,
        # not the minimum observed so far.
        self._best_valid_loss = vloss
        # return the loss after the training epoch
        return loss

    def _collect_ready_models(self, current_time):
        """Remove and return the buffered models generated before ``current_time``.

        The buffer is drained in priority order and the messages that are not
        ready yet are put back. Exactly the ready messages are therefore
        removed, whatever ordering ``Message`` defines.

        Returns:
            ``(models, active_neighs)``. ``models`` maps each source to its
            model's ``state_dict``; a later-drained message from the same
            source overwrites an earlier one. ``active_neighs`` lists the
            sources in drain order.
        """
        models = {}
        active_neighs = []
        not_ready = []
        while not self._msg_buffer.empty():
            msg = self._msg_buffer.get()
            if msg.time < current_time:
                if self.args.verbose:
                    print("TRAIN: adding model from", msg.source,
                          "generated at time", msg.time)
                # associate model with their source node for weighting
                models[msg.source] = msg.model.state_dict()
                active_neighs.append(msg.source)
            else:
                not_ready.append(msg)
        for msg in not_ready:
            self._msg_buffer.put(msg)
        return models, active_neighs

    def validate(self):
        """Return the validation loss recorded by the last :meth:`train` call."""
        return self._best_valid_loss

    def test(self, dataset):
        """Evaluate the local model on ``dataset``.

        The model is put in eval mode for the pass and its previous mode is
        restored afterwards. Loss and accuracy are averaged over samples
        rather than over batches, so a short final batch is not over-weighted.
        Inputs and labels are moved to ``args.device``; the model is assumed
        to already live there (TODO confirm with the caller).

        Returns:
            ``(loss, accuracy)``.

        Raises:
            ValueError: if ``dataset`` yields no samples.
        """
        total_loss = 0.0
        total_correct = 0.0
        n_samples = 0
        ldr_dataset = DataLoader(dataset, batch_size=self.args.local_bs)

        was_training = self._model.training
        self._model.eval()
        try:
            with torch.no_grad():
                for images, labels in ldr_dataset:
                    images = images.to(self.args.device)
                    labels = labels.to(self.args.device)

                    outputs = self._model(images)
                    batch_size = labels.size(0)

                    # CrossEntropyLoss() uses mean reduction: rescale to a sum
                    total_loss += self._loss_func(outputs, labels).item() * batch_size
                    # argmax of the raw outputs equals argmax of their softmax
                    total_correct += (outputs.argmax(1) == labels).sum().item()
                    n_samples += batch_size
        finally:
            self._model.train(was_training)

        if n_samples == 0:
            raise ValueError("test() received an empty dataset")
        return total_loss / n_samples, total_correct / n_samples

    # ------------------------------------------------------------------
    # graph helpers
    # ------------------------------------------------------------------
    def get_dst_list(self):
        """Return this PAIV's neighbours in the graph (presumably its message destinations)."""
        return self._graph.neighbors(self._id)

    def get_max_trust_in_neigh(self):
        """Return the largest edge ``'weight'`` (trust) towards any neighbour.

        Raises ``ValueError`` (from ``np.max``) when there are no neighbours.
        """
        return np.max(list(self.get_trust_in_neighs().values()))

    def get_trust_in_neighs(self):
        """Map each neighbour id to the ``'weight'`` of the edge towards it."""
        # everything static so far, so this could be computed once at init
        return {n: self._graph[self._id][n]['weight']
                for n in self._graph.neighbors(self._id)}

    def reset_model(self):
        """Reset the local model (not implemented: currently a no-op)."""
        pass

    # ------------------------------------------------------------------
    # aggregation strategies
    # ------------------------------------------------------------------
    def _social_aggregation(self, active_neighs, models):
        """Replace the local model with a trust-weighted average.

        The local model is weighted by ``selfconfidence``. Each received model
        is weighted by the edge trust towards its source. A source that is not
        a neighbour raises ``KeyError``.
        """
        trust_dict = self.get_trust_in_neighs()
        trust_active_dict = {k: trust_dict[k]
                             for k in active_neighs if k in trust_dict}

        # prepare models and trust values for averaging
        w_list = [self._model.state_dict()]
        t_list = [self._selfconfidence]
        for source, weights in models.items():
            w_list.append(weights)
            t_list.append(trust_active_dict[source])

        # do the averaging
        self._model.load_state_dict(SocialFedAvg(w_list, t_list))

    def _fedavg_aggregation(self, models):
        """Replace the local model with the plain ``FedAvg`` of ``models``."""
        self._model.load_state_dict(FedAvg(list(models.values())))

    def _substitute_local_with_global_model(self, models: dict):
        """Load the last model in ``models`` into the local model.

        The last model is the one ``dict.popitem`` would return, but
        ``models`` is not mutated. An empty dict leaves the model unchanged.
        """
        if not models:
            return
        self._model.load_state_dict(list(models.values())[-1])
