#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
from torch import nn


def FedAvg(w):
    """Return the element-wise, unweighted mean of a list of state dicts.

    Args:
        w: non-empty sequence of state dicts (mappings of key -> tensor)
            that all contain the keys of ``w[0]``.

    Returns:
        A deep copy of ``w[0]`` whose entries are replaced by the mean of
        that entry over every dict in ``w``; the inputs are left untouched.
        ``torch.div`` performs true division, so integer entries come back
        as floating point.

    Raises:
        IndexError: if ``w`` is empty.
    """
    count = len(w)
    averaged = copy.deepcopy(w[0])
    for key in averaged:
        running = averaged[key]
        for other in w[1:]:
            # Accumulates in place on the deep copy, never on w[0] itself.
            running += other[key]
        averaged[key] = torch.div(running, count)
    return averaged


# socially weighted federated averaging
def SocialFedAvg(models: list, trust: list):
    '''
    Computes the weighted federated average of the models
    received from the paiv's social graph, weighted by their social trust.

    For every key k the result is
    sum_i(trust[i] * models[i][k]) / sum(trust), so the weights do not
    need to be normalized beforehand.

    Args:
        models: non-empty list of state dicts (mappings of key -> tensor)
            that all contain the keys of ``models[0]``.
        trust: one weight per model, in the same order as ``models``.

    Returns:
        A new mapping of the same type (and with the same attributes) as
        ``models[0]`` holding the weighted averages; the inputs are not
        modified. Integer entries combined with float weights come back
        as floating point.

    Raises:
        ValueError: if ``models`` is empty, if ``trust`` does not contain
            exactly one weight per model, or if the weights sum to zero.
    '''
    if not models:
        raise ValueError("SocialFedAvg requires at least one model")
    if len(trust) != len(models):
        raise ValueError(
            f"expected {len(models)} trust weights, got {len(trust)}")
    total_trust = sum(trust)
    if total_trust == 0:
        # Dividing by zero would silently fill the model with nan/inf.
        raise ValueError("trust weights must not sum to zero")

    # A shallow copy preserves the container type and its attributes
    # (e.g. a state dict's ``_metadata``); every value is replaced below,
    # so no input tensor is ever modified.
    w_avg = copy.copy(models[0])
    for k in w_avg.keys():
        # Out-of-place arithmetic lets type promotion produce a float
        # result for integer entries; an in-place ``*=`` on an integer
        # tensor with a float weight would raise instead.
        acc = models[0][k] * trust[0]
        for model, weight in zip(models[1:], trust[1:]):
            acc = acc + model[k] * weight
        w_avg[k] = torch.div(acc, total_trust)
    return w_avg
