from inversefed.nn import MetaMonkey
import torch
from collections import defaultdict, OrderedDict
import numpy as np
import torch.optim as optim
from copy import deepcopy

def local_train_net(net,selected,target_id,validloader,num_images,lr,local_steps,batch_size, device="cpu",loss_fn=None,useDP=0,priv_accountant=None,dp_config=None):
    """Train a private copy of ``net`` per selected client and average the results.

    For every ``net_id`` in ``selected`` a deep copy of ``net`` is trained with
    plain SGD on ``num_images`` consecutive samples of ``validloader.dataset``
    starting at ``target_id + net_id * num_images``.  The resulting state
    dicts are averaged with equal weight ``1 / len(selected)``.

    Args:
        net: Global model; it is never modified (only deep copies are trained).
        selected: Sequence of client ids taking part in this round.
        target_id: Dataset index of the first sample of client 0.
        validloader: Object exposing an indexable ``dataset`` of ``(img, label)``.
        num_images: Number of samples per client.
        lr: SGD learning rate.
        local_steps: Number of passes over the client's samples.
        batch_size: Mini-batch size; ``num_images // batch_size`` batches are
            used per pass, so trailing samples that do not fill a batch are ignored.
        device: Device the client copies are trained on.
        loss_fn: Ignored; a mean-reduced ``CrossEntropyLoss`` is always used.
        useDP: 0 = no DP, 1 = per-layer gradient clipping + noise on the
            aggregate, 2 = noise on the aggregate scaled by the median update
            norm, 3 = ``PrivacyEngine`` attached to the client optimizer.
        priv_accountant: Accountant used when ``useDP`` is 1 or 2.
        dp_config: Dict read for ``noise_multiplier``/``delta`` (any DP mode),
            ``sensitivity`` (mode 1) and ``totalNum``/``max_grad_norm`` (mode 3).

    Returns:
        ``(gpara, update_0)``: the averaged state dict and, if client 0 took
        part, the list ``(original - trained) / lr`` for each of its
        parameters (otherwise ``None``).
    """
    import math
    import statistics

    rt = deepcopy(net)

    fed_avg_freqs = 1.0 / len(selected)
    gpara = rt.state_dict()
    update_0 = None
    l2norms = []

    # Read the DP settings once, up front, so they are bound even when no
    # training step runs (e.g. num_images < batch_size).
    if useDP != 0:
        noise_multiplier = dp_config['noise_multiplier']
        delta = dp_config['delta']
    if useDP == 1:
        S = dp_config['sensitivity']

    for net_id in selected:
        print("Training network %s." % (str(net_id)))
        tar_net = deepcopy(net)
        tar_net.to(device)
        inputs, labels = [], []
        target_id_ = target_id + net_id * num_images

        # Collect this client's consecutive slice of the dataset.
        while len(labels) < num_images:
            img, label = validloader.dataset[target_id_]
            target_id_ += 1
            labels.append(torch.as_tensor((label,), device=device))
            inputs.append(img.to(device))

        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        optimizer = optim.SGD(tar_net.parameters(), lr=lr)
        if useDP == 3:
            totalNum = dp_config['totalNum']
            from pytorch_dp_master.torchdp.privacy_engine import PrivacyEngine
            privacy_engine = PrivacyEngine(
                tar_net,
                {'bs':batch_size,'totalData':totalNum},
                alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
                noise_multiplier=dp_config['noise_multiplier'],
                max_grad_norm=dp_config['max_grad_norm'],
                useDP=useDP
            )
            privacy_engine.attach(optimizer)

        for i in range(local_steps):
            for j in range(num_images // batch_size):
                optimizer.zero_grad()
                start = j * batch_size
                batch_inputs = torch.stack(inputs[start:start + batch_size])
                batch_labels = torch.cat(labels[start:start + batch_size])

                output = tar_net(batch_inputs)
                loss = loss_fn(output, batch_labels)

                gradients = torch.autograd.grad(loss, tar_net.parameters(), retain_graph=True)
                if useDP == 1:
                    # Clip every layer's gradient to S / sqrt(m).  All
                    # gradients (including 0-dim ones) are kept so the list
                    # stays aligned with tar_net.parameters() below.
                    bound = S / math.sqrt(len(gradients))
                    gradients = [
                        g / max(1, torch.norm(g.reshape(-1)) / bound)
                        for g in gradients
                    ]
                for p, grad in zip(tar_net.parameters(), gradients):
                    p.grad = grad

                optimizer.step()
                if useDP == 3:
                    epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
                        delta
                    )

        if useDP == 2:
            # Median per-parameter update norm of this client; it is used as
            # the clipping threshold here and feeds the global sensitivity.
            update_i = [
                (param_origin - param).detach()
                for param, param_origin in zip(tar_net.parameters(), net.parameters())
            ]
            l2norm = statistics.median([torch.norm(u) for u in update_i])
            l2norms.append(l2norm)
            torch.nn.utils.clip_grad_norm_(tar_net.parameters(), l2norm)
            optimizer.step()

        if net_id == 0:
            update_0 = [
                ((param_origin - param) / lr).detach()
                for param, param_origin in zip(tar_net.parameters(), net.parameters())
            ]
        net_para = tar_net.cpu().state_dict()

        # Equal-weight averaging; the first selected client overwrites the
        # initial values, every later one is accumulated.
        if net_id == selected[0]:
            for key in net_para:
                gpara[key] = net_para[key] * fed_avg_freqs
        else:
            for key in net_para:
                gpara[key] += net_para[key] * fed_avg_freqs

    if useDP == 2:
        sensitivity = statistics.median(l2norms).item()

    if useDP == 1 or useDP == 2:
        if useDP == 1:
            sigma = noise_multiplier * S / len(selected)
        else:
            sigma = noise_multiplier * sensitivity / len(selected)
        with torch.no_grad():
            for key in gpara:
                value = gpara[key]
                # Draw the noise in floating point and cast it to the entry's
                # own dtype; casting to int64 first would truncate almost all
                # of it to zero for sigma < 1.
                noise = torch.randn(value.shape, device=value.device) * sigma
                value += noise.to(value.dtype)
        priv_accountant.accumulate_privacy_spending(noise_multiplier, len(selected))
        print("-----------", priv_accountant.get_privacy_spent(target_deltas=[delta]))
    return gpara, update_0
''''''
def local_train_net_no(net,selected,target_id,validloader,num_images,lr,local_steps,batch_size, device="cpu",loss_fn=None,vit_num=0):
    """Run one round of local SGD without any privacy mechanism and average it.

    Each client in ``selected`` trains its own deep copy of ``net`` (put in
    eval mode) on samples read from ``validloader.dataset``; the trained state
    dicts are averaged with weight ``1 / len(selected)``.

    Sampling rules:
      * client 0 with ``vit_num != 0`` takes ``vit_num`` consecutive samples,
        or ``batch_size`` of them when ``vit_num`` is smaller than a batch;
        duplicate labels are allowed in this case;
      * every other case takes ``num_images`` samples whose labels are all
        different, skipping samples whose label was already seen.

    ``loss_fn`` is ignored: a mean-reduced ``CrossEntropyLoss`` is always used.

    Returns:
        ``(gpara, [update_0, tar_0, bn_prior])`` where ``update_0`` is the
        per-parameter difference ``original - trained`` of client 0 (``None``
        if client 0 did not take part), ``tar_0`` the dataset indices client 0
        trained on, and ``bn_prior`` a list that this function leaves empty.
    """
    averaged = deepcopy(net).state_dict()
    weight = 1.0/len(selected)
    update_0 = None
    tar_0, bn_prior = [], []
    default_num_images = num_images

    for net_id in selected:
        tar_net = deepcopy(net)
        tar_net.to(device)
        inputs, labels = [], []
        # NOTE: the offset uses the *current* num_images, which may still hold
        # the value set for a previous client (it is only reset further down).
        cursor = target_id + net_id*num_images
        tar_net.eval()

        is_victim = net_id == 0
        if is_victim:
            import inversefed
            import torch.nn as nn
            # Attach a statistics hook to every BatchNorm2d layer of the copy.
            bn_layers = [inversefed.BNStatisticsHook(layer)
                         for layer in tar_net.modules()
                         if isinstance(layer, nn.BatchNorm2d)]

        if is_victim and vit_num != 0:
            # TODO: if the victim's sample count cannot fill a batch_size, the
            # victim's labels may contain duplicates as well.
            num_images = vit_num if vit_num >= batch_size else batch_size
            while len(labels) < num_images:
                img, label = validloader.dataset[cursor]
                tar_0.append(cursor)
                cursor += 1
                labels.append(torch.as_tensor((label,), device=device))
                inputs.append(img.to(device))
        else:
            num_images = default_num_images
            while len(labels) < num_images:
                img, label = validloader.dataset[cursor]
                cursor += 1
                if label in labels:
                    # Label already collected: skip this sample.
                    continue
                if is_victim:
                    tar_0.append(cursor - 1)
                labels.append(torch.as_tensor((label,), device=device))
                inputs.append(img.to(device))

        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        optimizer = optim.SGD(tar_net.parameters(), lr=lr)

        batches_per_pass = num_images//batch_size
        for _ in range(local_steps):
            for batch_idx in range(batches_per_pass):
                optimizer.zero_grad()
                lo = batch_idx*batch_size
                hi = lo + batch_size
                batch_x = torch.stack(inputs[lo:hi])
                batch_y = torch.cat(labels[lo:hi])

                loss = loss_fn(tar_net(batch_x), batch_y)
                gradients = torch.autograd.grad(loss, tar_net.parameters(), retain_graph=True)

                # Hand the explicitly computed gradients to the optimizer.
                for param, grad in zip(tar_net.parameters(), gradients):
                    param.grad = grad
                optimizer.step()

        if is_victim:
            # Parameter-wise difference between the global and trained model.
            update_0 = [(before - after).detach()
                        for after, before in zip(tar_net.parameters(), net.parameters())]

        net_para = tar_net.cpu().state_dict()

        # Equal-weight averaging: the first selected client initialises the
        # sum, later clients are added onto it.
        starts_sum = net_id == selected[0]
        for key in net_para:
            if starts_sum:
                averaged[key] = net_para[key] * weight
            else:
                averaged[key] += net_para[key] * weight

    return averaged, [update_0, tar_0, bn_prior]
def local_train_net_CDPL(net,selected,target_id,validloader,num_images,lr,local_steps,batch_size, device="cpu",loss_fn=None,priv_accountant=None,dp_config=None):
    """Federated round with per-layer gradient clipping and noise on the aggregate.

    Every client in ``selected`` trains a deep copy of ``net`` (in eval mode)
    with SGD on ``num_images`` consecutive samples of ``validloader.dataset``
    starting at ``target_id + net_id * num_images``.  Each layer's gradient is
    clipped to ``sensitivity / sqrt(number_of_gradients)`` before the step.
    The trained state dicts are averaged with equal weight and Gaussian noise
    with standard deviation ``noise_multiplier * sensitivity / len(selected)``
    is added to every entry of the average.

    Args:
        net: Global model; only deep copies of it are trained.
        selected: Sequence of participating client ids.
        target_id: Dataset index of the first sample of client 0.
        validloader: Object exposing an indexable ``dataset`` of ``(img, label)``.
        num_images: Samples per client.
        lr: SGD learning rate.
        local_steps: Number of passes over the client's samples.
        batch_size: Mini-batch size (``num_images // batch_size`` batches per pass).
        device: Device the client copies are trained on.
        loss_fn: Ignored; a mean-reduced ``CrossEntropyLoss`` is always used.
        priv_accountant: Object providing ``accumulate_privacy_spending`` and
            ``get_privacy_spent``.
        dp_config: Dict with ``noise_multiplier``, ``delta`` and ``sensitivity``.

    Returns:
        ``(gpara, [update_0, tar_0, bn_prior])``.  The last three describe
        client 0: its parameter update ``original - trained`` (``None`` if it
        did not take part), the dataset indices it used, and the
        ``(mean, var)`` pairs read from its BatchNorm hooks (empty list if it
        did not take part).
    """
    import math

    rt = deepcopy(net)

    fed_avg_freqs = 1.0 / len(selected)
    noise_multiplier = dp_config['noise_multiplier']
    delta = dp_config['delta']
    # Read once up front so sigma below is defined even if no step ran.
    S = dp_config['sensitivity']
    gpara = rt.state_dict()
    update_0 = None
    tar_0 = []
    # Initialised here so the return value is defined when client 0 is absent.
    bn_prior = []

    for net_id in selected:
        print("Training network %s." % (str(net_id)))
        tar_net = deepcopy(net)
        tar_net.to(device)
        tar_net.eval()
        inputs, labels = [], []
        target_id_ = target_id + net_id * num_images

        bn_layers = []
        if net_id == 0:
            import inversefed
            import torch.nn as nn
            for module in tar_net.modules():
                if isinstance(module, nn.BatchNorm2d):
                    bn_layers.append(inversefed.BNStatisticsHook(module))

        while len(labels) < num_images:
            img, label = validloader.dataset[target_id_]
            if net_id == 0:
                tar_0.append(target_id_)
            target_id_ += 1
            labels.append(torch.as_tensor((label,), device=device))
            inputs.append(img.to(device))

        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        optimizer = optim.SGD(tar_net.parameters(), lr=lr)

        for i in range(local_steps):
            for j in range(num_images // batch_size):
                optimizer.zero_grad()
                start = j * batch_size
                batch_inputs = torch.stack(inputs[start:start + batch_size])
                batch_labels = torch.cat(labels[start:start + batch_size])

                output = tar_net(batch_inputs)
                loss = loss_fn(output, batch_labels)

                gradients = torch.autograd.grad(loss, tar_net.parameters(), retain_graph=True)

                # Per-layer clipping.  Every gradient (0-dim ones included) is
                # kept so the list stays aligned with tar_net.parameters().
                bound = S / math.sqrt(len(gradients))
                gradients = [
                    g / max(1, torch.norm(g.reshape(-1)) / bound)
                    for g in gradients
                ]
                for p, grad in zip(tar_net.parameters(), gradients):
                    p.grad = grad

                optimizer.step()

        if net_id == 0:
            bn_prior = [
                (mod.mean_var[0].detach(), mod.mean_var[1].detach())
                for mod in bn_layers
            ]
            update_0 = [
                (param_origin - param).detach()
                for param, param_origin in zip(tar_net.parameters(), net.parameters())
            ]
        net_para = tar_net.cpu().state_dict()

        # Equal-weight averaging; the first selected client overwrites the
        # initial values, every later one is accumulated.
        if net_id == selected[0]:
            for key in net_para:
                gpara[key] = net_para[key] * fed_avg_freqs
        else:
            for key in net_para:
                gpara[key] += net_para[key] * fed_avg_freqs

    sigma = noise_multiplier * S / len(selected)

    with torch.no_grad():
        for key in gpara:
            value = gpara[key]
            # Keep the noise in floating point and cast it to the entry's own
            # dtype; casting to int64 first would truncate nearly all of it
            # to zero for sigma < 1.
            noise = torch.randn(value.shape, device=value.device) * sigma
            value += noise.to(value.dtype)
    priv_accountant.accumulate_privacy_spending(noise_multiplier, len(selected))
    print("-----------", priv_accountant.get_privacy_spent(target_deltas=[delta]))
    return gpara, [update_0, tar_0, bn_prior]
def local_train_net_CDPM(net,selected,target_id,validloader,num_images,lr,local_steps,batch_size, device="cpu",loss_fn=None,priv_accountant=None,dp_config=None):
    """Federated round with whole-model gradient clipping and noise on the aggregate.

    Same procedure as the per-layer variant, except that each step clips the
    gradient by its global L2 norm (all gradients concatenated) to
    ``dp_config['sensitivity']``.  After equal-weight averaging, Gaussian
    noise with standard deviation
    ``noise_multiplier * sensitivity / len(selected)`` is added to every entry
    of the averaged state dict.  The mean global gradient norm over all steps
    is printed when at least one step ran.

    Args:
        net: Global model; only deep copies of it are trained.
        selected: Sequence of participating client ids.
        target_id: Dataset index of the first sample of client 0.
        validloader: Object exposing an indexable ``dataset`` of ``(img, label)``.
        num_images: Samples per client.
        lr: SGD learning rate.
        local_steps: Number of passes over the client's samples.
        batch_size: Mini-batch size (``num_images // batch_size`` batches per pass).
        device: Device the client copies are trained on.
        loss_fn: Ignored; a mean-reduced ``CrossEntropyLoss`` is always used.
        priv_accountant: Object providing ``accumulate_privacy_spending`` and
            ``get_privacy_spent``.
        dp_config: Dict with ``noise_multiplier``, ``delta`` and ``sensitivity``.

    Returns:
        ``(gpara, [update_0, tar_0, bn_prior])``.  The last three describe
        client 0: its parameter update ``original - trained`` (``None`` if it
        did not take part), the dataset indices it used, and the
        ``(mean, var)`` pairs read from its BatchNorm hooks (empty list if it
        did not take part).
    """
    rt = deepcopy(net)

    fed_avg_freqs = 1.0 / len(selected)
    noise_multiplier = dp_config['noise_multiplier']
    delta = dp_config['delta']
    # Read once up front so sigma below is defined even if no step ran.
    S = dp_config['sensitivity']
    gpara = rt.state_dict()
    update_0 = None
    l2norms = []
    tar_0 = []
    # Initialised here so the return value is defined when client 0 is absent.
    bn_prior = []

    for net_id in selected:
        print("Training network %s." % (str(net_id)))
        tar_net = deepcopy(net)
        tar_net.to(device)
        tar_net.eval()
        inputs, labels = [], []
        target_id_ = target_id + net_id * num_images

        bn_layers = []
        if net_id == 0:
            import inversefed
            import torch.nn as nn
            for module in tar_net.modules():
                if isinstance(module, nn.BatchNorm2d):
                    bn_layers.append(inversefed.BNStatisticsHook(module))

        while len(labels) < num_images:
            img, label = validloader.dataset[target_id_]
            if net_id == 0:
                tar_0.append(target_id_)
            target_id_ += 1
            labels.append(torch.as_tensor((label,), device=device))
            inputs.append(img.to(device))

        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        optimizer = optim.SGD(tar_net.parameters(), lr=lr)

        for i in range(local_steps):
            for j in range(num_images // batch_size):
                optimizer.zero_grad()
                start = j * batch_size
                batch_inputs = torch.stack(inputs[start:start + batch_size])
                batch_labels = torch.cat(labels[start:start + batch_size])

                output = tar_net(batch_inputs)
                loss = loss_fn(output, batch_labels)

                gradients = torch.autograd.grad(loss, tar_net.parameters(), retain_graph=True)

                # Global L2 norm over the concatenation of all gradients.
                l2_norm = torch.norm(torch.cat([g.reshape(-1) for g in gradients], dim=0))
                # Recorded once per step; the mean printed below is unchanged
                # compared to recording it once per parameter.
                l2norms.append(l2_norm)

                # Scale every gradient by the same factor.  All gradients
                # (0-dim ones included) are kept so the list stays aligned
                # with tar_net.parameters().
                scale = max(1, l2_norm / S)
                gradients = [g / scale for g in gradients]
                for p, grad in zip(tar_net.parameters(), gradients):
                    p.grad = grad

                optimizer.step()

        if net_id == 0:
            bn_prior = [
                (mod.mean_var[0].detach(), mod.mean_var[1].detach())
                for mod in bn_layers
            ]
            update_0 = [
                (param_origin - param).detach()
                for param, param_origin in zip(tar_net.parameters(), net.parameters())
            ]
        net_para = tar_net.cpu().state_dict()

        # Equal-weight averaging; the first selected client overwrites the
        # initial values, every later one is accumulated.
        if net_id == selected[0]:
            for key in net_para:
                gpara[key] = net_para[key] * fed_avg_freqs
        else:
            for key in net_para:
                gpara[key] += net_para[key] * fed_avg_freqs

    # torch.stack rejects an empty list, so only report when a step ran.
    if l2norms:
        print(torch.mean(torch.stack(l2norms)).item())
    sigma = noise_multiplier * S / len(selected)

    with torch.no_grad():
        for key in gpara:
            value = gpara[key]
            # Keep the noise in floating point and cast it to the entry's own
            # dtype; casting to int64 first would truncate nearly all of it
            # to zero for sigma < 1.
            noise = torch.randn(value.shape, device=value.device) * sigma
            value += noise.to(value.dtype)
    priv_accountant.accumulate_privacy_spending(noise_multiplier, len(selected))
    print("-----------", priv_accountant.get_privacy_spent(target_deltas=[delta]))
    return gpara, [update_0, tar_0, bn_prior]
def local_train_net_LDPL(net,selected,target_id,validloader,num_images,lr,local_steps,batch_size, device="cpu",loss_fn=None,priv_accountant=None,dp_config=None):
    """Federated round where each client trains with a ``PrivacyEngine`` (``useDP=3``).

    Every client in ``selected`` trains a deep copy of ``net`` (in eval mode)
    with SGD on ``num_images`` consecutive samples of ``validloader.dataset``
    starting at ``target_id + net_id * num_images``.  A ``PrivacyEngine``
    built from ``dp_config`` is attached to the client's optimizer before
    training.  The trained state dicts are averaged with equal weight; no
    noise is added to the aggregate by this function.

    Args:
        net: Global model; only deep copies of it are trained.
        selected: Sequence of participating client ids.
        target_id: Dataset index of the first sample of client 0.
        validloader: Object exposing an indexable ``dataset`` of ``(img, label)``.
        num_images: Samples per client.
        lr: SGD learning rate.
        local_steps: Number of passes over the client's samples.
        batch_size: Mini-batch size (``num_images // batch_size`` batches per pass).
        device: Device the client copies are trained on.
        loss_fn: Ignored; a mean-reduced ``CrossEntropyLoss`` is always used.
        priv_accountant: Not used by this function.
        dp_config: Dict with ``noise_multiplier``, ``delta``, ``totalNum`` and
            ``max_grad_norm``.

    Returns:
        ``(gpara, [update_0, tar_0, bn_prior])``.  The last three describe
        client 0: its parameter update ``original - trained`` (``None`` if it
        did not take part), the dataset indices it used, and the
        ``(mean, var)`` pairs read from its BatchNorm hooks (empty list if it
        did not take part).
    """
    import math

    rt = deepcopy(net)

    fed_avg_freqs = 1.0 / len(selected)
    noise_multiplier = dp_config['noise_multiplier']
    delta = dp_config['delta']
    gpara = rt.state_dict()
    update_0 = None
    tar_0 = []
    # Initialised here so the return value is defined when client 0 is absent.
    bn_prior = []

    for net_id in selected:
        print("Training network %s." % (str(net_id)))
        tar_net = deepcopy(net)
        tar_net.to(device)
        tar_net.eval()
        inputs, labels = [], []
        target_id_ = target_id + net_id * num_images

        bn_layers = []
        if net_id == 0:
            import inversefed
            import torch.nn as nn
            for module in tar_net.modules():
                if isinstance(module, nn.BatchNorm2d):
                    bn_layers.append(inversefed.BNStatisticsHook(module))

        while len(labels) < num_images:
            img, label = validloader.dataset[target_id_]
            if net_id == 0:
                tar_0.append(target_id_)
            target_id_ += 1
            labels.append(torch.as_tensor((label,), device=device))
            inputs.append(img.to(device))

        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        optimizer = optim.SGD(tar_net.parameters(), lr=lr)

        from pytorch_dp_master.torchdp.privacy_engine import PrivacyEngine
        privacy_engine = PrivacyEngine(
            tar_net,
            {'bs':batch_size,'totalData':dp_config['totalNum']},
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=dp_config['noise_multiplier'],
            max_grad_norm=dp_config['max_grad_norm'],
            useDP=3
        )
        privacy_engine.attach(optimizer)

        for i in range(local_steps):
            for j in range(num_images // batch_size):
                optimizer.zero_grad()
                start = j * batch_size
                batch_inputs = torch.stack(inputs[start:start + batch_size])
                batch_labels = torch.cat(labels[start:start + batch_size])

                output = tar_net(batch_inputs)
                loss = loss_fn(output, batch_labels)

                gradients = torch.autograd.grad(loss, tar_net.parameters(), retain_graph=True)

                # Hand the explicitly computed gradients to the optimizer.
                for p, grad in zip(tar_net.parameters(), gradients):
                    p.grad = grad

                optimizer.step()

            # Privacy budget of this pass; the value is computed but not
            # reported or returned.
            if local_steps == 1 and num_images == batch_size:
                # Single-step case: closed-form budget.
                # NOTE(review): this uses log10; the usual Gaussian-mechanism
                # bound is written with the natural log -- confirm intent.
                epsilon = math.sqrt(2*math.log10(1.25/dp_config['delta']))/(dp_config['noise_multiplier']*dp_config['max_grad_norm'])
            else:
                epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
                    delta
                )

        if net_id == 0:
            bn_prior = [
                (mod.mean_var[0].detach(), mod.mean_var[1].detach())
                for mod in bn_layers
            ]
            update_0 = [
                (param_origin - param).detach()
                for param, param_origin in zip(tar_net.parameters(), net.parameters())
            ]
        net_para = tar_net.cpu().state_dict()

        # Equal-weight averaging; the first selected client overwrites the
        # initial values, every later one is accumulated.
        if net_id == selected[0]:
            for key in net_para:
                gpara[key] = net_para[key] * fed_avg_freqs
        else:
            for key in net_para:
                gpara[key] += net_para[key] * fed_avg_freqs

    return gpara, [update_0, tar_0, bn_prior]
def local_train_net_LDPM(net,selected,target_id,validloader,num_images,lr,local_steps,batch_size, device="cpu",loss_fn=None,priv_accountant=None,dp_config=None):
    """Federated round where each client trains with a ``PrivacyEngine`` (``useDP=4``).

    Every client in ``selected`` trains a deep copy of ``net`` (in eval mode)
    with SGD on ``num_images`` consecutive samples of ``validloader.dataset``
    starting at ``target_id + net_id * num_images``.  A ``PrivacyEngine``
    built from ``dp_config`` with ``useDP=4`` is attached to the client's
    optimizer before training.  The trained state dicts are averaged with
    equal weight; no noise is added to the aggregate by this function.

    Args:
        net: Global model; only deep copies of it are trained.
        selected: Sequence of participating client ids.
        target_id: Dataset index of the first sample of client 0.
        validloader: Object exposing an indexable ``dataset`` of ``(img, label)``.
        num_images: Samples per client.
        lr: SGD learning rate.
        local_steps: Number of passes over the client's samples.
        batch_size: Mini-batch size (``num_images // batch_size`` batches per pass).
        device: Device the client copies are trained on.
        loss_fn: Ignored; a mean-reduced ``CrossEntropyLoss`` is always used.
        priv_accountant: Not used by this function.
        dp_config: Dict with ``noise_multiplier``, ``delta``, ``totalNum`` and
            ``max_grad_norm``.

    Returns:
        ``(gpara, [update_0, tar_0, bn_prior])``.  The last three describe
        client 0: its parameter update ``original - trained`` (``None`` if it
        did not take part), the dataset indices it used, and the
        ``(mean, var)`` pairs read from its BatchNorm hooks (empty list if it
        did not take part).
    """
    import math

    rt = deepcopy(net)

    fed_avg_freqs = 1.0 / len(selected)
    noise_multiplier = dp_config['noise_multiplier']
    delta = dp_config['delta']
    gpara = rt.state_dict()
    update_0 = None
    tar_0 = []
    # Initialised here so the return value is defined when client 0 is absent.
    bn_prior = []

    for net_id in selected:
        print("Training network %s." % (str(net_id)))
        tar_net = deepcopy(net)
        tar_net.to(device)
        tar_net.eval()
        inputs, labels = [], []
        target_id_ = target_id + net_id * num_images

        bn_layers = []
        if net_id == 0:
            import inversefed
            import torch.nn as nn
            for module in tar_net.modules():
                if isinstance(module, nn.BatchNorm2d):
                    bn_layers.append(inversefed.BNStatisticsHook(module))

        while len(labels) < num_images:
            img, label = validloader.dataset[target_id_]
            if net_id == 0:
                tar_0.append(target_id_)
            target_id_ += 1
            labels.append(torch.as_tensor((label,), device=device))
            inputs.append(img.to(device))

        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
        optimizer = optim.SGD(tar_net.parameters(), lr=lr)

        from pytorch_dp_master.torchdp.privacy_engine import PrivacyEngine
        privacy_engine = PrivacyEngine(
            tar_net,
            {'bs':batch_size,'totalData':dp_config['totalNum']},
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=dp_config['noise_multiplier'],
            max_grad_norm=dp_config['max_grad_norm'],
            useDP=4
        )
        privacy_engine.attach(optimizer)

        for i in range(local_steps):
            for j in range(num_images // batch_size):
                optimizer.zero_grad()
                start = j * batch_size
                batch_inputs = torch.stack(inputs[start:start + batch_size])
                batch_labels = torch.cat(labels[start:start + batch_size])

                output = tar_net(batch_inputs)
                loss = loss_fn(output, batch_labels)

                gradients = torch.autograd.grad(loss, tar_net.parameters(), retain_graph=True)

                # Hand the explicitly computed gradients to the optimizer.
                for p, grad in zip(tar_net.parameters(), gradients):
                    p.grad = grad

                optimizer.step()

            # Privacy budget of this pass; the value is computed but not
            # reported or returned.
            if local_steps == 1 and num_images == batch_size:
                # Single-step case: closed-form budget.
                # NOTE(review): this uses log10; the usual Gaussian-mechanism
                # bound is written with the natural log -- confirm intent.
                epsilon = math.sqrt(2*math.log10(1.25/dp_config['delta']))/(dp_config['noise_multiplier']*dp_config['max_grad_norm'])
            else:
                epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(
                    delta
                )

        if net_id == 0:
            bn_prior = [
                (mod.mean_var[0].detach(), mod.mean_var[1].detach())
                for mod in bn_layers
            ]
            update_0 = [
                (param_origin - param).detach()
                for param, param_origin in zip(tar_net.parameters(), net.parameters())
            ]
        net_para = tar_net.cpu().state_dict()

        # Equal-weight averaging; the first selected client overwrites the
        # initial values, every later one is accumulated.
        if net_id == selected[0]:
            for key in net_para:
                gpara[key] = net_para[key] * fed_avg_freqs
        else:
            for key in net_para:
                gpara[key] += net_para[key] * fed_avg_freqs

    return gpara, [update_0, tar_0, bn_prior]
def compute_accuracy(model, dataloader, get_confusion_matrix=False, moon_model=False, device="cpu"):
    """Return the top-1 accuracy of ``model`` over one or several data loaders.

    The model is switched to eval mode (and left there) and evaluated under
    ``torch.no_grad()``.

    Args:
        model: Callable module.  With ``moon_model=True`` it must return a
            3-tuple whose last element holds the class scores.
        dataloader: An iterable of ``(x, target)`` batches, or a ``list`` of
            such iterables which are evaluated one after another.
        get_confusion_matrix: Accepted for interface compatibility; the
            return value is the accuracy either way.
        moon_model: Whether ``model`` returns ``(_, _, scores)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when no sample was seen.
    """
    model.eval()

    loaders = dataloader if isinstance(dataloader, list) else [dataloader]

    correct, total = 0, 0
    with torch.no_grad():
        for loader in loaders:
            for x, target in loader:
                x, target = x.to(device), target.to(device, dtype=torch.int64)
                if moon_model:
                    _, _, out = model(x)
                else:
                    out = model(x)
                # Index of the highest score along dim 1 is the prediction.
                _, pred_label = torch.max(out.data, 1)

                total += x.data.size()[0]
                correct += (pred_label == target.data).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct/float(total)
