"""Run reconstruction in a terminal prompt.
Optional arguments can be found in inversefed/options.py

This CLI can recover the baseline experiments.
"""

import torch
import torchvision
import torch.nn as nn

import numpy as np

import inversefed
torch.backends.cudnn.benchmark = inversefed.consts.BENCHMARK

from collections import defaultdict
import datetime
import time
import os
import json
import hashlib

import argparse
import random
# Fixed seed so the per-class image ids drawn with random.sample are reproducible.
random.seed(42)

# Number of output classes for each supported --dataset value.
nclass_dict = {'I32': 1000, 'I64': 1000, 'I128': 1000, 'I256': 1000,
               'CIFAR10': 10, 'CIFAR100': 100, 'CA': 8, 'ImageNet': 1000,
               'FFHQ': 10, 'FFHQ64': 10, 'FFHQ128': 10, 'MNIST': 10,
               'MNIST_GRAY': 10, 'CelebA': 2,
               }

# Parse input arguments
parser = argparse.ArgumentParser(
    description='Measure how much loss gradients change under small input perturbations.')
parser.add_argument('--model', default='0', type=str, help='model')
parser.add_argument('--dataset', default='0', type=str, help='dataset')

parser.add_argument('--device', default='0', type=str, help='gpu')

parser.add_argument('--result_path', default='results_IGSA/', type=str,
                    help='directory in which the norm files are written')
# The parsed --result_path is used as given; it is no longer overwritten with the default.
args = parser.parse_args()

# Parse training strategy (used when constructing the dataloaders).
defs = inversefed.training_strategy('conservative')
defs.epochs = 1

def _loss_gradients(model, inputs, label):
    """Return detached cross-entropy gradients for every parameter, in ``model.parameters()`` order."""
    loss = torch.nn.functional.cross_entropy(model(inputs), label)
    gradients = torch.autograd.grad(loss, tuple(model.parameters()))
    return [grad.detach() for grad in gradients]


def calculate_jacobian_norm(model, data, label, args, target_id, pth, filename, K, p=1e-3):
    """Estimate how strongly each parameter's loss gradient reacts to input noise.

    The cross-entropy gradient is computed once at ``data`` and once at each of
    ``K`` noisy copies ``data + p * randn_like(data)``. For every parameter
    tensor, the L2 norm of the gradient difference is averaged over the ``K``
    draws. Despite the function name, the result is not divided by the noise
    magnitude; it is a Monte-Carlo average of raw gradient-difference norms.

    The averages are appended, one per line (``str`` of a 0-dim tensor), to
    ``<args.result_path>/<filename with ':' replaced by '-'>/norms_<target_id>_<pth>.txt``.

    Args:
        model: network to probe; it is switched to ``eval()`` mode as a side effect.
        data: input passed directly to ``model``.
        label: targets accepted by ``torch.nn.functional.cross_entropy``.
        args: namespace that provides ``result_path``.
        target_id: identifier used only in the output file name.
        pth: identifier used only in the output file name.
        filename: run identifier used as the output sub-directory name.
        K: number of noise samples; must be at least 1.
        p: standard deviation of the Gaussian input noise.

    Returns:
        A list of 0-dim CPU tensors, one per entry of ``model.parameters()``.

    Raises:
        ValueError: if ``K < 1``, since the average would be undefined.
    """
    if K < 1:
        raise ValueError(f"K must be a positive number of noise samples, got {K}")

    filename = filename.replace(':', '-')
    model.eval()

    base_gradients = _loss_gradients(model, data, label)

    # Running per-parameter sums; int 0 + tensor promotes to a tensor on the first draw.
    norm_sums = [0] * len(base_gradients)
    for _ in range(K):
        # randn_like already matches data's device and dtype.
        perturbed = data + torch.randn_like(data) * p
        perturbed_gradients = _loss_gradients(model, perturbed, label)
        norm_sums = [
            total + torch.norm(grad_k - grad, 2)
            for total, grad_k, grad in zip(norm_sums, perturbed_gradients, base_gradients)
        ]

    mean_norms = [total.cpu() / K for total in norm_sums]

    out_dir = os.path.join(args.result_path, filename)
    os.makedirs(out_dir, exist_ok=True)
    txt_path = os.path.join(out_dir, f'norms_{target_id!s}_{pth!s}.txt')

    # Append mode: repeated calls with the same ids accumulate lines in one file.
    # str() (not format()) keeps the "tensor(...)" representation in the file.
    with open(txt_path, 'a+', encoding='utf-8') as f:
        f.writelines(str(norm) + '\n' for norm in mean_norms)

    return mean_norms

if __name__ == "__main__":

    setup = inversefed.utils.system_startup(args)

    filename="model:"+args.model+"_dataset:"+args.dataset+"_device"+args.device


    if args.dataset.startswith('I'):
        
        loss_fn, trainloader, validloader = inversefed.construct_dataloaders(args.dataset, defs, data_path='~/data')
    else:
        loss_fn, trainloader, validloader = inversefed.construct_dataloaders(args.dataset, defs, data_path='~/data')

    
    if args.dataset.startswith('FFHQ'):
        dm = torch.as_tensor(getattr(inversefed.consts, f'cifar10_mean'), **setup)[:, None, None]
        ds = torch.as_tensor(getattr(inversefed.consts, f'cifar10_std'), **setup)[:, None, None]
    elif args.dataset.startswith('I'):

        '''
        dm= torch.tensor(
        [[[0.485]],

        [[0.456]],

        [[0.406]]]).cuda()

        ds= torch.tensor(
        [[[0.229]],

        [[0.224]],

        [[0.225]]]).cuda()
        '''
        dm= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)

        ds= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)   
    elif args.dataset=="CelebA":
        dm= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)

        ds= torch.tensor(
        [[[0.5]],

        [[0.5]],

        [[0.5]]]).cuda("cuda:"+args.device)   
    else:
        dm = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_mean'), **setup)[:, None, None]
        ds = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_std'), **setup)[:, None, None]

    labels = trainloader.dataset.targets


    ids = []

    for label in range(10):  
        label_indices = [i for i, lbl in enumerate(labels) if lbl == label]
        random_selected_indices = random.sample(label_indices, 5)  
        ids.extend(random_selected_indices)
    
    # config ~\inversefed\nn\models.py for initializing models with pth
    for pth in range(100):
        model, model_seed = inversefed.construct_model(args.model, num_classes=nclass_dict[args.dataset], num_channels=3, seed=0,pth=pth)
        model.to(**setup)
        model.eval()
        
        for target_id in ids:
            ''''''
            tid_list = []
            if args.dataset.startswith('I'):
                ground_truth, labels = validloader.dataset[target_id]
            else:
                ground_truth, labels = trainloader.dataset[target_id]
            ground_truth, labels = ground_truth.unsqueeze(0).to(**setup), torch.as_tensor((labels,), device=setup['device'])
            calculate_jacobian_norm(model,ground_truth,labels,args,target_id,pth,filename,100)