"""Run reconstruction in a terminal prompt.
Optional arguments can be found in inversefed/options.py

This CLI can recover the baseline experiments.
"""

import torch
import torchvision
import torch.nn as nn

import numpy as np

import inversefed
torch.backends.cudnn.benchmark = inversefed.consts.BENCHMARK

from collections import defaultdict
import datetime
import time
import os
import json
import hashlib
import csv
import copy
import pickle
import fedavg

# Number of classes per dataset name; looked up with args.dataset to obtain
# the num_classes value handed to the model constructor.
nclass_dict = dict(
    I32=1000,
    I64=1000,
    I128=1000,
    I256=1000,
    CIFAR10=10,
    CIFAR100=100,
    CA=8,
    ImageNet=1000,
    FFHQ=10,
    FFHQ64=10,
    FFHQ128=10,
    MNIST=10,
    MNIST_GRAY=10,
    CelebA=2,
)
# Parse input arguments: start from the shared option set returned by
# inversefed.options() and register the flags specific to this script.
parser = inversefed.options()
# --unsigned switches signed updates OFF (see `args.signed` below), so the help
# text must describe the negative.
parser.add_argument('--unsigned', action='store_true', help='Disable signed gradient descent (signed updates are used by default)')
parser.add_argument('--num_exp', default=1, type=int, help='Number of consecutive experiments')
parser.add_argument('--max_iterations', default=500, type=int, help='Maximum number of iterations for reconstruction.')
parser.add_argument('--gias_iterations', default=0, type=int, help='Maximum number of gias iterations for reconstruction.')
# NOTE(review): the seed is parsed as a float although seeds are normally
# integers; the type is kept so existing command lines keep working. The
# consumer of args.seed is not visible in this part of the file -- confirm.
parser.add_argument('--seed', default=1234, type=float, help='Random seed')

parser.add_argument('--batch_size', default=4, type=int, help='Mini-batch size for federated averaging')
parser.add_argument('--local_lr', default=1e-2, type=float, help='Local learning rate for federated averaging')
parser.add_argument('--useFL', action='store_true', help='useFL==True, switch to FL mode')
parser.add_argument('--clients', default=5, type=int, help='Number of clients for federated averaging')
parser.add_argument('--rounds', default=10, type=int, help='Number of communication rounds for federated averaging')
parser.add_argument('--sample', default=0.3, type=float, help='pop of joined clients for federated averaging')
parser.add_argument('--vit_num', default=0, type=int, help='vit bench')
# The main script unpickles this file to obtain the generator G; an empty
# string means no checkpoint is loaded.
parser.add_argument('--checkpoint_path', default='', type=str, help='Path to a pickled generator checkpoint (empty string: do not load one)')
parser.add_argument('--device', default='0', type=str, help='gpu')
parser.add_argument('--pth', default=0, type=int, help='pthRound')
parser.add_argument('--skip', default=-1, type=int, help='cut skip id')
parser.add_argument('--config_id', default=0, type=int, help='for ConCNN')
args = parser.parse_args()

# Derived settings that the rest of the script reads from `args`.
if args.target_id is None:
    args.target_id = 0
args.save_image = True
args.signed = not args.unsigned


# Parse training strategy
defs = inversefed.training_strategy('conservative')
defs.epochs = args.epochs
def landscape(model, data, label, args, filename, K, p=1e-3):
    """Plot the gradient shift along one random input direction, then exit.

    The parameter gradient of the cross-entropy loss is computed once at
    ``data`` and then at 1000 inputs ``data + x * noise``, where ``noise`` is
    a single ``torch.rand_like(data)`` draw and ``x`` sweeps evenly over
    ``[-0.1, 0.1]``. For every ``x`` the Euclidean norm (taken over all
    parameters together) of the difference to the reference gradient is
    recorded, and the resulting curve is saved to ``landscape.png`` in the
    current working directory.

    Side effects: puts ``model`` into eval mode, writes ``landscape.png``,
    calls ``plt.show()`` and finally terminates the interpreter through
    ``sys.exit()``. This function therefore never returns to its caller.

    Args:
        model: Module whose parameters the gradients are taken with respect to.
        data: Input batch passed to ``model``.
        label: Targets for ``torch.nn.functional.cross_entropy``.
        args: Accepted for signature parity with the sibling helper; unused.
        filename: Accepted for signature parity; unused.
        K: Accepted for signature parity; unused.
        p: Accepted for signature parity; unused.
    """
    import math
    import sys

    import matplotlib.pyplot as plt

    num_points = 1000
    epsilon = 0.1

    print("landscape painting")
    model.eval()

    params = list(model.parameters())

    def _param_gradients(inputs):
        # Detached gradients of the cross-entropy loss w.r.t. every parameter.
        loss = torch.nn.functional.cross_entropy(model(inputs), label)
        return [g.detach() for g in torch.autograd.grad(loss, params)]

    reference = _param_gradients(data)

    x = np.linspace(-epsilon, epsilon, num_points)
    z = np.zeros_like(x)
    # One fixed direction is reused for the whole sweep; only its scale varies.
    noise = torch.rand_like(data)
    for i, offset in enumerate(x):
        shifted = _param_gradients(data + noise * float(offset))
        squared_distance = sum(
            (g_k - g).pow(2).sum() for g_k, g in zip(shifted, reference)
        )
        z[i] = math.sqrt(squared_distance)

    plt.figure(figsize=(8, 6))
    # NOTE(review): the legend text 'Real Part' does not describe this curve
    # and looks copied from elsewhere; it is kept unchanged here.
    plt.plot(x, z, label='Real Part')
    plt.xlabel('x')
    plt.ylabel('z')
    plt.title('z-x Curve')
    plt.legend()

    plt.savefig("landscape.png")
    plt.show()
    # The site-provided exit() is meant for interactive sessions and may be
    # absent (e.g. `python -S`); sys.exit() is the programmatic equivalent.
    sys.exit()
def calculate_jacobian_norm(model, data, label, args, filename, K, p=1e-3):
    """Average, per parameter tensor, how far the gradient moves under input noise.

    The parameter gradient of the cross-entropy loss is computed once at
    ``data``. Then, ``K`` times, ``data`` is perturbed with Gaussian noise
    scaled by ``p`` and the gradient is recomputed; for every parameter tensor
    the 2-norm of the difference to the reference gradient is accumulated and
    finally divided by ``K``.

    Side effects: puts ``model`` into eval mode, creates the directory
    ``<args.result_path>/<filename with ':' replaced by '-'>``, appends one
    line per parameter tensor to ``norms_<target_id>_<pth>.txt`` there, and
    saves a line plot to ``jacobian_norms_<target_id>_<pth>.png``.

    Args:
        model: Module whose parameters the gradients are taken with respect to.
        data: Input batch passed to ``model``.
        label: Targets for ``torch.nn.functional.cross_entropy``.
        args: Namespace providing ``result_path``, ``target_id`` and ``pth``.
        filename: Run identifier used as the output sub-directory name.
        K: Number of noise draws to average over; must be positive.
        p: Scale applied to the standard-normal noise.

    Returns:
        A list with one averaged norm (0-dim tensor) per entry of
        ``model.parameters()``, in the same order.

    Raises:
        ValueError: If ``K`` is not positive.
    """
    # Validate before any work: K == 0 used to end in a ZeroDivisionError after
    # the reference pass, and K < 0 silently produced meaningless values.
    if K <= 0:
        raise ValueError(f"K must be a positive number of noise samples, got {K!r}")

    import matplotlib.pyplot as plt

    filename = filename.replace(':', '-')
    print("strat_calculate_jacobian_norm")

    model.eval()
    params = list(model.parameters())

    def _param_gradients(inputs):
        # Detached gradients of the cross-entropy loss w.r.t. every parameter.
        loss = torch.nn.functional.cross_entropy(model(inputs), label)
        return [g.detach() for g in torch.autograd.grad(loss, params)]

    reference = _param_gradients(data)

    norm_sums = [0] * len(params)
    for _ in range(K):
        # randn_like already allocates on data's device with data's dtype.
        perturbed_data = data + torch.randn_like(data) * p
        shifted = _param_gradients(perturbed_data)
        norm_sums = [
            total + torch.norm(g_k - g, 2)
            for total, g_k, g in zip(norm_sums, shifted, reference)
        ]
    mean_norms = [total / K for total in norm_sums]

    out_dir = os.path.join(args.result_path, filename)
    os.makedirs(out_dir, exist_ok=True)
    suffix = '_' + str(args.target_id) + '_' + str(args.pth)

    # Append mode: repeated runs accumulate in the same file. Each line is the
    # str() form of the tensor, exactly as before.
    with open(os.path.join(out_dir, 'norms' + suffix + '.txt'), 'a+') as f:
        for norm in mean_norms:
            f.write(str(norm) + '\n')

    print(len(mean_norms))
    # Use a dedicated figure and close it afterwards so repeated calls neither
    # overlay their curves nor keep figures alive. Values are converted to
    # plain floats because CUDA tensors cannot be turned into NumPy arrays
    # implicitly.
    fig, ax = plt.subplots()
    ax.plot([float(norm) for norm in mean_norms])
    ax.set_xlabel('Layer Number')
    ax.set_ylabel('Average Jacobian Norm')
    fig.savefig(os.path.join(out_dir, 'jacobian_norms' + suffix + '.png'))
    plt.close(fig)
    print("finish_calculate_jacobian_norm")
    return mean_norms

if __name__ == "__main__":
    # Choose GPU device and print status information:
    setup = inversefed.utils.system_startup(args)
    start_time = time.time()
    # Prepare for training
    filename="model:"+args.model+"_dataset:"+args.dataset+"_attack:"+args.optim+"_iteration:"+str(args.max_iterations)+"_localNum:"+str(args.num_images)+"_epoch:"+str(args.accumulation)+"_localBs:"+str(args.batch_size)+"_clients:"+str(args.clients)+"_sample:"+str(args.sample)+"_rounds:"+str(args.rounds)+"_vitNum:"+str(args.vit_num)+"_device"+args.device

    
    if args.dataset.startswith('I'):
        #restore imagenet
        #loss_fn, trainloader, validloader = inversefed.construct_dataloaders(args.dataset, defs, data_path='E:\\mydataset')
        loss_fn, trainloader, validloader = inversefed.construct_dataloaders(args.dataset, defs, data_path=args.data_path)
    else:
        loss_fn, trainloader, validloader = inversefed.construct_dataloaders(args.dataset, defs, data_path=args.data_path)

    model, model_seed = inversefed.construct_model(args.model, num_classes=nclass_dict[args.dataset], num_channels=3, seed=0,pth=args.pth,skip=args.skip,config_id=args.config_id)
    #print(model)
    #exit()
    #exit()
    if args.dataset.startswith('FFHQ'):
        dm = torch.as_tensor(getattr(inversefed.consts, f'cifar10_mean'), **setup)[:, None, None]
        ds = torch.as_tensor(getattr(inversefed.consts, f'cifar10_std'), **setup)[:, None, None]
    elif args.dataset.startswith('I'):

        '''
        dm= torch.tensor(
        [[[0.485]],

        [[0.456]],

        [[0.406]]]).cuda()

        ds= torch.tensor(
        [[[0.229]],

        [[0.224]],

        [[0.225]]]).cuda()
        '''
        dm= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)

        ds= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)   
    elif args.dataset=="CelebA":
        dm= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)

        ds= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)   
    else:
        dm = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_mean'), **setup)[:, None, None]
        ds = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_std'), **setup)[:, None, None]
    #print(dm,ds)
    #exit()
    #print(dm.size(),ds.size())
    #exit()
    #model = nn.DataParallel(model)
    model.to(**setup)
    model.eval()

    if args.optim == 'GIA-L':
        config = dict(signed=args.signed,
                      cost_fn=args.cost_fn,
                      indices=args.indices,
                      weights=args.weights,
                      lr=args.lr if args.lr is not None else 0.1,
                      optim='adam',
                      restarts=args.restarts,
                      max_iterations=args.max_iterations,
                      total_variation=args.tv,
                      bn_stat=args.bn_stat,
                      image_norm=args.image_norm,
                      z_norm=args.z_norm,
                      group_lazy=args.group_lazy,
                      init=args.init,
                      lr_decay=False,
                      dataset=args.dataset,
                      r_patch=args.r_patch,
                      generative_model=args.generative_model,
                      gen_dataset=args.gen_dataset,
                      giml=args.giml,
                      gias=args.gias,
                      gias_lr=args.gias_lr,
                      gias_iterations=args.gias_iterations,
                      r_april=args.r_april,
                      )
    elif args.optim == 'yin':

        config = dict(signed=args.signed,
                      cost_fn=args.cost_fn,
                      indices=args.indices,
                      weights=args.weights,
                      lr=args.lr if args.lr is not None else 0.1,
                      optim='adam',

                      restarts=args.restarts,
                      max_iterations=args.max_iterations,
                      total_variation=1e-4,

                      bn_stat=1e-4,#0.1,1e-4
                      image_norm=1e-4,#1e-4,
                      z_norm=args.z_norm,
                      group_lazy=1e-4,#0.01,1e-4
                      init=args.init,
                      lr_decay=True,
                      dataset=args.dataset,
                      r_patch=args.r_patch,
                      generative_model='',
                      gen_dataset='',
                      giml=False,
                      gias=False,
                      gias_lr=0.0,
                      gias_iterations=args.gias_iterations,
                      r_april=args.r_april,
                      )
    elif args.optim == 'gradvit':

        config = dict(signed=args.signed,
                      cost_fn='l2',
                      indices=args.indices,
                      weights=args.weights,
                      lr=args.lr if args.lr is not None else 0.1,
                      optim='adam',

                      restarts=args.restarts,
                      max_iterations=args.max_iterations,
                      total_variation=1e-4,#1e-4,
                      bn_stat=0,#0.1,1e-4
                      r_patch=0,#-2
                      image_norm=0,#1e-6,#1e-6,
                      z_norm=args.z_norm,
                      group_lazy=0,#1e-2,#1e-2,#0.01,
                      init=args.init,
                      lr_decay=True,
                      dataset=args.dataset,
                      
                      generative_model='',
                      gen_dataset='',
                      giml=False,
                      gias=False,
                      gias_lr=0.0,
                      gias_iterations=args.gias_iterations,
                      r_april=args.r_april,
                      )
    elif args.optim == 'GIA-O':
        config = dict(signed=args.signed,
                      cost_fn=args.cost_fn,
                      indices=args.indices,
                      weights=args.weights,
                      lr=args.lr if args.lr is not None else 0.1,
                      optim='adam',
                      restarts=args.restarts,
                      max_iterations=args.max_iterations,
                      total_variation=args.tv,
                      bn_stat=args.bn_stat,
                      image_norm=args.image_norm,
                      z_norm=args.z_norm,
                      group_lazy=args.group_lazy,
                      init=args.init,
                      lr_decay=True,
                      dataset=args.dataset,
                      r_patch=args.r_patch,
                      generative_model='',
                      gen_dataset='',
                      giml=False,
                      gias=False,
                      gias_lr=0.0,
                      gias_iterations=0,
                      r_april=args.r_april,
                      ) 
    elif args.optim == 'gen':
        config = dict(signed=args.signed,
                      cost_fn=args.cost_fn,
                      indices=args.indices,
                      weights=args.weights,
                      lr=args.lr if args.lr is not None else 0.1,
                      optim='adam',
                      restarts=args.restarts,
                      max_iterations=args.max_iterations,
                      total_variation=args.tv,
                      bn_stat=args.bn_stat,
                      image_norm=args.image_norm,
                      z_norm=args.z_norm,
                      group_lazy=args.group_lazy,
                      init=args.init,
                      lr_decay=True,
                      dataset=args.dataset,
                      r_patch=args.r_patch,
                      generative_model=args.generative_model,
                      gen_dataset=args.gen_dataset,
                      giml=False,
                      gias=False,
                      gias_lr=0.0,
                      gias_iterations=0,
                      r_april=args.r_april,
                      )
    elif args.optim == 'april':

        config = dict(signed=args.signed,
                      cost_fn='l2',
                      indices=args.indices,
                      weights=args.weights,
                      lr=args.lr if args.lr is not None else 0.1,
                      optim='adam',
                      restarts=args.restarts,
                      max_iterations=args.max_iterations,
                      total_variation=args.tv,
                      bn_stat=args.bn_stat,
                      image_norm=args.image_norm,
                      z_norm=args.z_norm,
                      group_lazy=args.group_lazy,
                      init=args.init,
                      lr_decay=False,
                      dataset=args.dataset,
                      r_patch=args.r_patch,
                      generative_model='',
                      gen_dataset='',
                      giml=False,
                      gias=False,
                      gias_lr=0.0,
                      gias_iterations=0,
                      r_april=1e-2,
                      )      
    elif args.optim == 'zhu':
        config = dict(signed=False,
                      cost_fn='l2',
                      indices='def',
                      weights='equal',
                      lr=args.lr if args.lr is not None else 1.0,
                      optim='LBFGS',
                      restarts=args.restarts,
                      max_iterations=500,
                      total_variation=args.tv,
                      init=args.init,
                      lr_decay=False,
                      bn_stat= args.bn_stat,
                      image_norm= args.image_norm,
                      z_norm = args.z_norm,
                      group_lazy= args.group_lazy,
                      dataset=args.dataset,
                      generative_model='',
                      gen_dataset='',
                      giml=False,
                      gias=False,
                      gias_lr=0.0,
                      gias_iterations=0,
                      r_patch=args.r_patch,
                      r_april=args.r_april,
                      )
    elif args.optim == 'wang':
        config = dict(signed=args.signed,
                      cost_fn='l2',
                      indices='def',
                      weights='equal',
                      lr=args.lr if args.lr is not None else 1.0,
                      optim='adam',
                      restarts=args.restarts,
                      max_iterations=args.max_iterations,
                      total_variation=args.tv,
                      init=args.init,
                      lr_decay=False,
                      bn_stat= args.bn_stat,
                      image_norm= args.image_norm,
                      z_norm = args.z_norm,
                      group_lazy= args.group_lazy,
                      dataset=args.dataset,
                      generative_model='',
                      gen_dataset='',
                      giml=False,
                      gias=False,
                      gias_lr=0.0,
                      gias_iterations=0,
                      r_patch=args.r_patch,
                      r_april=args.r_april,
                      )                  
        # psnr list
    psnrs = []


    config_comp = config.copy()
    config_comp['optim'] = args.optim
    config_comp['dataset'] = args.dataset
    config_comp['model'] = args.model
    config_comp['trained'] = args.trained_model
    config_comp['num_exp'] = args.num_exp
    config_comp['num_images'] = args.num_images
    config_comp['bn_stat'] = args.bn_stat
    config_comp['image_norm'] = args.image_norm
    config_comp['z_norm'] = args.z_norm
    config_comp['group_lazy'] = args.group_lazy
    config_comp['checkpoint_path'] = args.checkpoint_path
    config_comp['accumulation'] = args.accumulation
    config_comp['batch_size'] = args.batch_size
    config_comp['local_lr'] = args.trained_model
    config_hash = hashlib.md5(json.dumps(config_comp, sort_keys=True).encode()).hexdigest()

    print(config_comp)

    os.makedirs(args.table_path, exist_ok=True)
    os.makedirs(os.path.join(args.table_path, f'{config_hash}'), exist_ok=True)
    os.makedirs(args.result_path, exist_ok=True)
    os.makedirs(os.path.join(args.result_path, f'{config_hash}'), exist_ok=True)

    G = None
    if args.checkpoint_path:
        with open(args.checkpoint_path, 'rb') as f:
            G, _ = pickle.load(f)
            G = G.requires_grad_(True).to(setup['device'])


    target_id = args.target_id
    for i in range(args.num_exp):
        ''''''
        target_id = args.target_id + i * args.num_images
        tid_list = []
        if args.num_images == 1:
            if args.dataset.startswith('I'):
                ground_truth, labels = validloader.dataset[target_id]
            else:
                ground_truth, labels = trainloader.dataset[target_id]
            ground_truth, labels = ground_truth.unsqueeze(0).to(**setup), torch.as_tensor((labels,), device=setup['device'])
            target_id_ = target_id + 1
            print("loaded img %d" % (target_id_ - 1))
            tid_list.append(target_id_ - 1)
        else:
            ground_truth, labels = [], []
            target_id_ = target_id
            
            while len(labels) < args.num_images:

                if args.dataset.startswith('I'):
                    img, label = validloader.dataset[target_id_]
                else:
                    img, label = trainloader.dataset[target_id_]
                target_id_ += 1

                if (label not in labels):
                #print("loaded img %d" % (target_id_ - 1))
                    labels.append(torch.as_tensor((label,), device=setup['device']))
                    ground_truth.append(img.to(**setup))
                    tid_list.append(target_id_ - 1)
                

            ground_truth = torch.stack(ground_truth)
            labels = torch.cat(labels)
        #print(ground_truth.shape)
        img_shape = (ground_truth.shape[1], ground_truth.shape[2], ground_truth.shape[3])
        
        # print(labels)

        # Run reconstruction
        if config['bn_stat'] > 0 and args.accumulation == 0:
            bn_layers = []
            for module in model.modules():
                if isinstance(module, nn.BatchNorm2d):
                    bn_layers.append(inversefed.BNStatisticsHook(module))

        useFL=False
        if args.accumulation == 0:
            #print(model(ground_truth))

            #landscape(model,ground_truth,labels,args,filename,100)
            #if args.model not in ["Swin"]: 
            #    calculate_jacobian_norm(model,ground_truth,labels,args,filename,100)
            target_loss, _, _ = loss_fn(model(ground_truth), labels)
            #print(model(ground_truth))
            #print(ground_truth.size())
            #print(target_loss)
            input_gradient = torch.autograd.grad(target_loss, model.parameters())
            input_gradient = [grad.detach() for grad in input_gradient]
            #print(input_gradient)
            #exit()
            bn_prior = []
            if config['bn_stat'] > 0:
                for idx, mod in enumerate(bn_layers):
                    mean_var = mod.mean_var[0].detach(), mod.mean_var[1].detach()
                    #print(mean_var)
                    bn_prior.append(mean_var)
            #print(bn_prior)
            # with open(f'exp_{i}_bn_prior.pkl', 'wb') as f:
            #     pickle.dump(bn_prior, f)

            rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=args.num_images, bn_prior=bn_prior, G=G,dev=args.device)

            if G is None:
                G = rec_machine.G

            output, stats = rec_machine.reconstruct(input_gradient, labels, img_shape=img_shape, dryrun=args.dryrun)
        elif args.accumulation >0 and args.useFL:
            #TODO add args
            num_clients=args.clients
            rounds=args.rounds

            local_gradient_steps = args.accumulation
            local_lr = args.local_lr
            batch_size = args.batch_size
            # Init models
            

            from copy import deepcopy
            #testP=model.parameters()
            gmodel=deepcopy(model)
            gpara=gmodel.state_dict()
            gmodel.to(setup['device'])
            print("Strat FL")
            from differential_privacy.privacy_accountant.pytorch import accountant
            priv_accountant = accountant.GaussianMomentsAccountant(num_clients)
            useDP=0
            filename+="_Defence:"+str(useDP)
            for round in range(rounds):

                if round == 0:
                    arr = np.arange(1,num_clients)
                    np.random.shuffle(arr)
                    selected =  arr[:int(num_clients * args.sample)-1]
                    selected=list(selected)
                    selected=[0]+selected
                else: 

                    arr = np.arange(1,num_clients)
                    np.random.shuffle(arr)
                    selected =  arr[:int(num_clients * args.sample)]
                    selected=list(selected)
                    #selected=[0]+selected
                
                '''
                arr = np.arange(num_clients)
                np.random.shuffle(arr)
                selected = arr[:int(num_clients * args.sample)]
                '''
                
                #print(selected)
                print("in comm round:" + str(round))
                # all the clients are selected
                # init of clients is different 
                
                start_id = args.target_id + i *num_clients* args.num_images
                '''
                from inversefed.nn import MetaMonkey
                model=MetaMonkey(model)
                target_loss, _, _ = loss_fn(model(ground_truth,model.parameters), labels)
                input_gradient = torch.autograd.grad(target_loss, model.parameters.values())
                input_parameters = [grad.detach() for grad in input_gradient]
                '''
                '''
                target_loss, _, _ = loss_fn(model(ground_truth), labels)
                input_gradient = torch.autograd.grad(target_loss, model.parameters())
                input_parameters = [grad.detach() for grad in input_gradient]
                '''
                

                if useDP==0:
                    para,update_0=fedavg.local_train_net_no(gmodel,selected=selected,validloader=trainloader,target_id=start_id,num_images=args.num_images,lr=local_lr,local_steps=local_gradient_steps, batch_size=batch_size,device=setup['device'],loss_fn=loss_fn,vit_num=args.vit_num)
                elif useDP==1:
                    dp_config={
                    'totalNum':args.num_images,
                    'noise_multiplier':0.65,#for CDP 3; for LDP 0.65
                    'max_grad_norm':1.0,
                    'delta':1e-5,
                    'sensitivity':1
                    }
                    para,update_0=fedavg.local_train_net_CDPL(gmodel,selected=selected,validloader=trainloader,target_id=start_id,num_images=args.num_images,lr=local_lr,local_steps=local_gradient_steps, batch_size=batch_size,device=setup['device'],loss_fn=loss_fn,priv_accountant=priv_accountant,dp_config=dp_config)
                elif useDP==2:
                    dp_config={
                    'totalNum':args.num_images,
                    'noise_multiplier':0.65,#for CDP 3; for LDP 0.65
                    'max_grad_norm':1.0,
                    'delta':1e-5,
                    'sensitivity':1
                    }
                    para,update_0=fedavg.local_train_net_CDPM(gmodel,selected=selected,validloader=trainloader,target_id=start_id,num_images=args.num_images,lr=local_lr,local_steps=local_gradient_steps, batch_size=batch_size,device=setup['device'],loss_fn=loss_fn,priv_accountant=priv_accountant,dp_config=dp_config)
                elif useDP==3:
                    dp_config={
                    'totalNum':args.num_images,
                    'noise_multiplier':0.65,#for CDP 3; for LDP 0.65
                    'max_grad_norm':1.0,
                    'delta':1e-5,
                    'sensitivity':1
                    }
                    para,update_0=fedavg.local_train_net_LDPL(gmodel,selected=selected,validloader=trainloader,target_id=start_id,num_images=args.num_images,lr=local_lr,local_steps=local_gradient_steps, batch_size=batch_size,device=setup['device'],loss_fn=loss_fn,priv_accountant=priv_accountant,dp_config=dp_config)
                elif useDP==4:
                    dp_config={
                    'totalNum':args.num_images,
                    'noise_multiplier':0.65,#for CDP 3; for LDP 0.65
                    'max_grad_norm':1.0,
                    'delta':1e-5,
                    'sensitivity':1
                    }
                    para,update_0=fedavg.local_train_net_LDPM(gmodel,selected=selected,validloader=trainloader,target_id=start_id,num_images=args.num_images,lr=local_lr,local_steps=local_gradient_steps, batch_size=batch_size,device=setup['device'],loss_fn=loss_fn,priv_accountant=priv_accountant,dp_config=dp_config)
                #para,update_0=fedavg.local_train_net(gmodel,selected=selected,validloader=trainloader,target_id=start_id,num_images=args.num_images,lr=local_lr,local_steps=local_gradient_steps, batch_size=batch_size,device=setup['device'],loss_fn=loss_fn,useDP=dp_config['useDP'],priv_accountant=priv_accountant,dp_config=dp_config)
                
                #para,update_0=fedavg.local_train_net(gmodel,selected=selected,validloader=trainloader,target_id=start_id,num_images=args.num_images,lr=local_lr,local_steps=local_gradient_steps, batch_size=batch_size,device=setup['device'],loss_fn=loss_fn)
                input_parameters=update_0[0]
                #print(input_parameters)
                #input_parameters = [p.detach() for p in input_parameters]
                #print(ground_truth.size(), labels)
                
                # iid
                #print(update_0)
                
                gmodel.load_state_dict(para)
                
                # TODO test
                test_acc= fedavg.compute_accuracy(gmodel, validloader, get_confusion_matrix=False, device=setup['device'])
                print('>> Global Model Test accuracy: %f' % test_acc)
                
                
                
                #print( [ p1-p2 for p1,p2 in zip(testP,model.parameters()) ])
                #nb=[torch.norm(p1.detach().clone()-p2.detach().clone()) for (n1,p1),(n2,p2) in zip(model.named_parameters(),gmodel.named_parameters())]

                if round==0:
                    attmodel=deepcopy(model)
                    
                    attmodel.to(setup['device'])
                    if args.vit_num != 0: 
                        if args.vit_num>=batch_size:
                            rec_machine = inversefed.FedAvgReconstructor(attmodel, (dm, ds), local_gradient_steps,
                                                                        local_lr, config,
                                                                        num_images=args.vit_num, use_updates=True,
                                                                        batch_size=batch_size,bn_prior=update_0[2])
                        else:
                            args.vit_num=batch_size
                            rec_machine = inversefed.FedAvgReconstructor(attmodel, (dm, ds), local_gradient_steps,
                                                                        local_lr, config,
                                                                        num_images=args.vit_num, use_updates=True,
                                                                        batch_size=batch_size,bn_prior=update_0[2])
                    else:
                        rec_machine = inversefed.FedAvgReconstructor(attmodel, (dm, ds), local_gradient_steps,
                                                                        local_lr, config,
                                                                        num_images=args.num_images, use_updates=True,
                                                                        batch_size=batch_size,bn_prior=update_0[2])
                    if G is None:
                        if rec_machine.generative_model_name in ['stylegan2']:
                            G = rec_machine.G_synthesis
                        else:
                            G = rec_machine.G
                    labels=[]
                    ground_truth=[]
                    tid_list=[]
                    
                    for idx in update_0[1]:
                        img, label = trainloader.dataset[idx]
                        labels.append(torch.as_tensor((label,), device=setup['device']))
                        ground_truth.append(img.to(**setup))
                        tid_list.append(idx)
                    labels = torch.cat(labels)
                    ground_truth = torch.stack(ground_truth)
                    #print(labels)
                    output, stats = rec_machine.reconstruct(input_parameters, labels, img_shape=img_shape, dryrun=args.dryrun)
                    #print(filename)
                ''''''
            
            print("End FL")
            #exit()
            

        else:
            
            # Non-FL mode: derive the parameter update from local steps on the target batch
            # and reconstruct the batch from that update.
            local_gradient_steps, local_lr, batch_size = args.accumulation, args.local_lr, args.batch_size

            raw_update = inversefed.reconstruction_algorithms.loss_steps(
                model, ground_truth, labels,
                lr=local_lr, local_steps=local_gradient_steps,
                use_updates=True, batch_size=batch_size)
            # Detach so the reconstruction does not backpropagate into the simulated update.
            input_parameters = [p.detach() for p in raw_update]

            rec_machine = inversefed.FedAvgReconstructor(
                model, (dm, ds), local_gradient_steps, local_lr, config,
                num_images=args.num_images, use_updates=True, batch_size=batch_size)

            # Remember the generator held by the reconstructor the first time through;
            # for 'stylegan2' it is read from `G_synthesis`, otherwise from `G`.
            if G is None:
                is_stylegan2 = rec_machine.generative_model_name in ['stylegan2']
                G = rec_machine.G_synthesis if is_stylegan2 else rec_machine.G

            output, stats = rec_machine.reconstruct(
                input_parameters, labels, img_shape=img_shape, dryrun=args.dryrun)

        # Compute stats and save to a table:
        # Map reconstruction and ground truth back to image space (x * ds + dm) and clamp
        # to [0, 1] so the pixel metrics compare valid images.
        output_den = torch.clamp(output * ds + dm, 0, 1)
        ground_truth_den = torch.clamp(ground_truth * ds + dm, 0, 1)
        # Feature-space MSE is evaluation only: skip autograd bookkeeping for both forward passes.
        with torch.no_grad():
            feat_mse = (model(output) - model(ground_truth)).pow(2).mean().item()
        test_mse = (output_den - ground_truth_den).pow(2).mean().item()
        test_psnr = inversefed.metrics.psnr(output_den, ground_truth_den, factor=1)
        test_lpips = inversefed.metrics.lpips_loss(output_den, ground_truth_den)

        # Format the metrics once; the same text is printed and reused in `note`.
        metrics_line = (f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} | PSNR: {test_psnr:4.2f} | "
                        f"FMSE: {feat_mse:2.4e} | LPIPS: {test_lpips:2.4f} ")
        print(metrics_line)
        # NOTE(review): the note is tagged with args.target_id while the table row below uses
        # the loop's target_id -- confirm that this difference is intended.
        note = f"{args.pth}, {args.target_id} : " + metrics_line

        # One row per experiment; the keyword order defines the column order of the table.
        inversefed.utils.save_to_table(os.path.join(args.table_path, f'{config_hash}'), name=f'mul_exp_{args.name}', dryrun=args.dryrun,

                                       config_hash=config_hash,
                                       model=args.model,
                                       dataset=args.dataset,
                                       trained=args.trained_model,
                                       restarts=args.restarts,
                                       OPTIM=args.optim,
                                       cost_fn=args.cost_fn,
                                       indices=args.indices,
                                       weights=args.weights,
                                       init=args.init,
                                       tv=args.tv,

                                       rec_loss=stats["opt"],
                                       psnr=test_psnr,
                                       test_mse=test_mse,
                                       feat_mse=feat_mse,

                                       target_id=target_id,
                                       seed=model_seed,
                                       epochs=defs.epochs,
                                       )


        # Save the resulting image
        if args.save_image and not args.dryrun:
            from PIL import Image

            output_denormalized = torch.clamp(output * ds + dm, 0, 1)

            # 1) Per-sample dumps into the config-hash directory. Make sure the directory
            #    exists first, otherwise writing the first image fails.
            hash_dir = os.path.join(args.result_path, f'{config_hash}')
            os.makedirs(hash_dir, exist_ok=True)
            for j in range(args.num_images):
                torchvision.utils.save_image(output_denormalized[j:j + 1, ...],
                                             os.path.join(hash_dir, f'{tid_list[j]}' + "_" + str(args.pth) + '.png'))
                torchvision.utils.save_image(ground_truth_den[j:j + 1, ...],
                                             os.path.join(hash_dir, f'{tid_list[j]}_gt.png'))

            # 2) Per-run directory holding the individual images plus one overview image.
            padding = 10  # gap in pixels between tiles of the overview image

            # presumably ':' is stripped because it is not a valid path character everywhere
            filename = filename.replace(':', '-')
            run_dir = os.path.join(args.result_path, filename)
            os.makedirs(run_dir, exist_ok=True)

            if args.vit_num > 0:
                # ViT mode: vit_num images, file names without the checkpoint tag.
                n_tiles = args.vit_num
                rec_suffix = '.png'
                gt_suffix = '_gt.png'
            else:
                # Default mode: num_images images, file names tagged with args.pth,
                # and the metrics note appended to results.txt (one note per line).
                n_tiles = args.num_images
                rec_suffix = "_" + str(args.pth) + '.png'
                gt_suffix = '_gt' + "_" + str(args.pth) + '.png'

                file_path = os.path.join(run_dir, 'results.txt')
                with open(file_path, 'a+') as file:
                    file.seek(0)
                    is_empty = not file.read(1)
                    file.seek(0, 2)
                    if not is_empty:
                        file.write('\n')
                    file.write(note)

            # Tensors are (N, C, H, W) while PIL sizes are (width, height): the canvas width
            # must come from W (dim 3) and its height from H (dim 2). Top row holds the
            # reconstructions, bottom row the ground truth.
            tile_h, tile_w = output_denormalized.shape[2:4]
            big_image = Image.new('RGB',
                                  (n_tiles * tile_w + (n_tiles - 1) * padding, 2 * tile_h + padding),
                                  (255, 255, 255))

            to_pil = torchvision.transforms.ToPILImage()
            for j in range(n_tiles):
                rec_tile = output_denormalized[j:j + 1, ...]
                gt_tile = ground_truth_den[j:j + 1, ...]
                torchvision.utils.save_image(rec_tile, os.path.join(run_dir, f'{tid_list[j]}' + rec_suffix))
                torchvision.utils.save_image(gt_tile, os.path.join(run_dir, f'{tid_list[j]}' + gt_suffix))

                # The canvas is created white, so the gaps between tiles need no extra painting.
                x_offset = j * (tile_w + padding)
                big_image.paste(to_pil(rec_tile.squeeze()), (x_offset, 0))
                big_image.paste(to_pil(gt_tile.squeeze()), (x_offset, tile_h + padding))

            big_image.save(os.path.join(run_dir, 'big_image.png'))

        # Record this experiment's PSNR in memory and as a row of the 'psnrs' table.
        psnrs.append(test_psnr)
        psnr_table_dir = os.path.join(args.table_path, f'{config_hash}')
        inversefed.utils.save_to_table(psnr_table_dir, name='psnrs', dryrun=args.dryrun,
                                       target_id=target_id, psnr=test_psnr)

        # Update target id (target_id_ is assigned outside this part of the loop body).
        target_id = target_id_


    # psnr statistics
    # nan_to_num replaces NaN / infinite PSNR entries with finite numbers so the
    # aggregates below stay finite.
    psnrs = np.nan_to_num(np.array(psnrs))
    if psnrs.size:
        psnr_mean = psnrs.mean()
        psnr_std = np.std(psnrs)
        psnr_max = psnrs.max()
        psnr_min = psnrs.min()
        psnr_median = np.median(psnrs)
        timing = datetime.timedelta(seconds=time.time() - start_time)
        inversefed.utils.save_to_table(os.path.join(args.table_path, f'{config_hash}'), name='psnr_stats', dryrun=args.dryrun,
                                       number_of_samples=len(psnrs),
                                       timing=str(timing),
                                       mean=psnr_mean,
                                       std=psnr_std,
                                       max=psnr_max,
                                       min=psnr_min,
                                       median=psnr_median)
    else:
        # max()/min() raise on an empty array; with no finished experiment there is
        # nothing to summarize, so skip the table instead of crashing at the very end.
        print('No PSNR values recorded; skipping psnr_stats table.')

    # Print final timestamp
    print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p"))
    print('---------------------------------------------------')
    print(f'Finished computations with time: {str(datetime.timedelta(seconds=time.time() - start_time))}')
    print('-------------Job finished.-------------------------')

