import os
import re
import matplotlib.pyplot as plt
import glob
import torch
import numpy as np
from matplotlib import rcParams
from matplotlib.ticker import FuncFormatter
import math
# Global matplotlib styling for every figure produced by this script:
# Times New Roman text, the STIX math font set, and 14 pt bold labels.
config = {
    "font.family": 'Times New Roman',
    "font.serif": ['Times New Roman'],
    "mathtext.fontset": 'stix',
}
rcParams.update(config)
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.size': 14,
    'font.weight': 'bold',
})

def extract_data_from_file(file_path):
    """Parse a results text file and average its metrics per label.

    Each non-blank line must have the form::

        <label>[, <anything>] : <name>:<value>|<name>:<value>|...

    The integer before the first ``", "`` is the grouping label. The first
    five ``name:value`` pairs are read, in order, as rec_loss, mse, psnr,
    fmse and lpips; the names themselves are ignored.

    Args:
        file_path: path of the text file to parse.

    Returns:
        Tuple ``(rec_loss, mse, psnr, fmse, lpips, labels)`` of parallel
        lists. Entry ``k`` of each metric list is the mean over every line
        whose label is ``labels[k]``. Labels are kept in first-seen order.

    Raises:
        ValueError: if a label or metric value is not numeric.
        IndexError: if a non-blank line lacks the `` : `` separator, a
            ``name:value`` pair lacks ``:``, or fewer than five metrics
            are present.
    """
    data_groups = {}
    with open(file_path, 'r') as file:
        for raw_line in file:
            line = raw_line.strip()
            if not line:
                continue  # blank lines carry no data; int('') would raise
            parts = line.split(' : ')
            label = int(parts[0].strip().split(', ')[0])
            metrics = [float(value.split(':')[1])
                       for value in parts[1].strip().split('|')]
            data_groups.setdefault(label, []).append(metrics)

    rec_loss, mse, psnr, fmse, lpips, labels = [], [], [], [], [], []
    for label, metric_list in data_groups.items():
        count = len(metric_list)
        # Column-wise mean of the first five metrics over this label's lines.
        means = [sum(row[k] for row in metric_list) / count for k in range(5)]
        for series, mean in zip((rec_loss, mse, psnr, fmse, lpips), means):
            series.append(mean)
        labels.append(label)
    return rec_loss, mse, psnr, fmse, lpips, labels

def plot_indicators_avg(file_path, n=1):
    """Return the per-label mean PSNR series of a results file.

    Despite the name, nothing is plotted. The file is parsed with
    ``extract_data_from_file`` and only the PSNR column is kept; the other
    averaged metrics are discarded.

    Args:
        file_path: path of a ``results_*.txt`` file.
        n: keep every ``n``-th point, starting with the first. The default
            of 1 keeps the whole series.

    Returns:
        A new list of PSNR means, one per label in first-seen order,
        subsampled by ``n``.

    Raises:
        ValueError: if ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")
    psnr = extract_data_from_file(file_path)[2]
    return psnr[::n]

def format_func(value, tick_number):
    """Render a tick value with one significant digit.

    Has the ``(value, position)`` signature used by
    ``matplotlib.ticker.FuncFormatter``; ``tick_number`` is ignored.
    """
    return format(value, '.1g')


# Norm files are named norms_<id>_<time step>.txt.
_NORM_FILE_RE = re.compile(r'norms_(\d+)_(\d+)\.txt')
# Scalar tensor reprs such as tensor(0.1234), tensor(5.) or tensor(1.2346e-05).
# The exponent must be captured; otherwise 1.2346e-05 would be read as 1.2346.
_TENSOR_VALUE_RE = re.compile(r'tensor\(([-+]?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)')


def _sample(data, num):
    """Return every ``num``-th element of ``data``, starting with the first."""
    return data[::num]


def _average_by_groups(lst, group_size):
    """Average consecutive chunks of ``group_size`` items.

    A shorter trailing chunk is averaged on its own and appears exactly once.
    """
    result = []
    for start in range(0, len(lst), group_size):
        group = lst[start:start + group_size]
        result.append(sum(group) / len(group))
    return result


def _load_norm_curve(folder_path, file_prefix="norms_", file_extension=".txt"):
    """Build the min-max normalised IGSA curve for one folder.

    For each ``norms_<id>_<time step>.txt`` file, the scalar ``tensor(...)``
    values are extracted. For each time step, these steps follow:

    1. The value lists of all ids are combined.
    2. Each position ``k`` is weighted by ``1 + 1/(k+1)``.
    3. The mean of the weighted values is inverted.

    The per-step results are then min-max scaled to [0, 1].

    Returns:
        One value per time step in ascending time-step order, or ``None``
        when the folder holds no files matching the naming pattern.
    """
    pattern = os.path.join(folder_path, f"{file_prefix}*{file_extension}")
    # time step -> id -> concatenated values of all files for that pair
    values_by_step = {}
    for file_path in sorted(glob.glob(pattern)):
        match = _NORM_FILE_RE.match(os.path.basename(file_path))
        if match is None:
            continue  # the glob also accepts names like norms_foo.txt
        run_id, time_step = (int(group) for group in match.groups())
        with open(file_path, 'r') as file:
            content = file.read()
        values = [float(val) for val in _TENSOR_VALUE_RE.findall(content)]
        values_by_step.setdefault(time_step, {}).setdefault(run_id, []).extend(values)

    if not values_by_step:
        return None

    inverse_means = {}
    for time_step, per_id in values_by_step.items():
        combined = []
        for values in per_id.values():
            if not combined:
                combined = [0] * len(values)
            # NOTE(review): each id's values are divided by the length of its
            # own value list, not by the number of ids. The final min-max
            # scaling removes this only if the factor is equal for every
            # time step. Confirm this is intended.
            combined = [acc + val / len(values) for acc, val in zip(combined, values)]
        weighted = [val * (1 + 1 / (k + 1)) for k, val in enumerate(combined)]
        if len(weighted) == 0:
            inverse_means[time_step] = 0
        else:
            inverse_means[time_step] = 1.0 / (sum(weighted) / len(weighted))

    # Order by time step so the curve lines up with the x-axis; glob order
    # is whatever the filesystem returns.
    curve = [inverse_means[step] for step in sorted(inverse_means)]
    lo, hi = min(curve), max(curve)
    if hi == lo:
        return [0.0] * len(curve)  # flat curve: avoid division by zero
    return [(val - lo) / (hi - lo) for val in curve]


def _load_test_curve(file_name, max_points=100):
    """Read up to ``max_points`` accuracy values, one per non-blank line."""
    with open(file_name, 'r') as file:
        stripped = [line.strip() for line in file]
    return [float(line) for line in stripped if line][:max_points]


def _plot_grid(rec, norm, test, splits, names, stride, out_path='IGSA_test.pdf'):
    """Draw the 3x3 panel grid and save it as a PDF.

    Panel ``i`` overlays three curves on y-axes that share the x-axis:
    ``rec[i]`` (PSNR, blue, y-range 0-20), ``norm[i]`` (IGSA, green) and
    ``test[i]`` (test accuracy, red, y-range 0-100). Y-axis labels are
    shown only on the outer columns, and x-axis labels only on the bottom row.
    """
    fig, axes = plt.subplots(3, 3, figsize=(20, 20), dpi=600)
    fig.subplots_adjust(wspace=0.03, hspace=0.15)

    def rounds(series):
        # Series were downsampled by ``stride``: point k is drawn at x = k*stride.
        return np.arange(len(series)) * stride

    for i, ax in enumerate(axes.flat):
        left_col = i % 3 == 0
        right_col = i % 3 == 2
        bottom_row = i >= 6

        ax.plot(rounds(rec[i]), rec[i], 's--', color='b', markersize=8)
        ax.set_ylim(0, 20)
        if left_col:
            ax.set_ylabel('PSNR', color='b', font={'size': 26})
            ax.tick_params(axis='y', labelcolor='b', labelsize=24)
        else:
            ax.set_yticks([])
            ax.set_ylabel('')
        if bottom_row:
            ax.tick_params(axis='x', labelsize=24)
            ax.set_xlabel('Rounds', font={'size': 28})
            ax.set_xticks([0, 25, 50, 75])
        else:
            ax.set_xticks([])
            ax.set_xlabel('')

        # Grey background over the first splits[i] fraction of the panel width
        # and white over the rest. xmin/xmax are axes-fraction coordinates.
        ax.axhspan(ymin=0, ymax=100, xmin=0, xmax=splits[i], color='lightgrey')
        ax.axhspan(ymin=0, ymax=100, xmin=splits[i], xmax=1, color='white')

        ax2 = ax.twinx()
        ax2.plot(rounds(norm[i]), norm[i], '*-', color='g', markersize=8)
        if right_col:
            ax2.set_ylabel('IGSA', color='g', font={'size': 26})
            ax2.tick_params(axis='y', labelcolor='g', labelsize=24)
        else:
            ax2.set_yticks([])
            ax2.set_ylabel('')

        ax3 = ax.twinx()
        ax3.plot(rounds(test[i]), test[i], 'o-.', color='r', markersize=8)
        ax3.set_ylim(0, 100)
        if right_col:
            # Push the accuracy spine outward so it does not sit on the IGSA spine.
            ax3.spines['right'].set_position(('axes', 1.20))
            ax3.spines['right'].set(linestyle='-.')
            ax3.set_ylabel('Test Accuracy (%)', color='r', font={'size': 28})
            ax3.tick_params(axis='y', labelcolor='r', labelsize=24)
        else:
            ax3.set_yticks([])
            ax3.set_ylabel('')

        ax.set_title(names[i], font={'size': 30})

    fig.savefig(out_path, format='pdf')


def main():
    """Collect every curve from the working directory and write IGSA_test.pdf.

    The current directory is expected to contain:

    - ``results_*.txt`` files, which supply the PSNR curves;
    - sub-folders holding ``norms_<id>_<time step>.txt`` files, which supply
      the IGSA curves;
    - ``test_*.txt`` files with one accuracy value per line.
    """
    stride = 5  # every plotted series keeps one point per `stride` entries

    result_files = sorted(f for f in os.listdir(".")
                          if f.startswith("results_") and f.endswith(".txt"))
    print(result_files)
    rec = [plot_indicators_avg(name) for name in result_files]

    folder_paths = sorted(f for f in os.listdir() if os.path.isdir(f))
    print(folder_paths)
    norm = []
    for folder_path in folder_paths:
        curve = _load_norm_curve(folder_path)
        if curve is None:
            # NOTE(review): rec, norm and test are indexed together below, so a
            # model folder without norms files would misalign the panels.
            print(f"No norms files found in {folder_path}")
            continue
        norm.append(curve)

    test_files = sorted(f for f in os.listdir(".")
                        if f.startswith("test_") and f.endswith(".txt"))
    print(test_files)
    test = [_load_test_curve(name) for name in test_files]

    rec = [_sample(r, stride) for r in rec]
    norm = [_average_by_groups(n, stride) for n in norm]
    test = [_sample(t, stride) for t in test]

    # Shaded fraction of each panel's x-axis, in sorted input order. The
    # trailing 1 is never selected by `idx` below.
    splits = [3.4/20, 6.3/20, 4.8/20, 4/20, 1.8/20, 5.7/20, 4.5/20, 4.7/20, 4.9/20, 1]

    # Maps the alphabetically sorted inputs onto the panel order of `names`.
    # This assumes the current file/folder naming; confirm if inputs change.
    idx = [6, 7, 8, 4, 5, 3, 0, 1, 2]
    rec = [rec[i] for i in idx]
    norm = [norm[i] for i in idx]
    test = [test[i] for i in idx]
    splits = [splits[i] for i in idx]

    names = [r"$\mathrm{Vgg-11}$", r"$\mathrm{Vgg-16}$", r"$\mathrm{Vgg-19}$",
             r"$\mathrm{ResNet-18}$", r"$\mathrm{ResNet-50}$", r"$\mathrm{ResNet-152}$",
             r"$\mathrm{DenseNet-121}$", r"$\mathrm{GoogleNet}$", r"$\mathrm{Inception-v3}$"]

    _plot_grid(rec, norm, test, splits, names, stride)


if __name__ == "__main__":
    main()

