import os
import sys
import pandas as pd
import re
import seaborn as sns
import matplotlib.pyplot as plt
# Use seaborn's 'darkgrid' style for the plots below (set once, at import time).
sns.set_theme(style='darkgrid')

# Names of the log events handled by this script: parse_file_content() records
# the time of every log line whose first token is one of these names, and the
# __main__ block plots all of them when --event is not given.
EVENTS = [
    # Disabled entry, kept for reference:
    #'TSStep',
    'KSPSolve',
    'SNESSolve',
    'SNESFunctionEval',
    'SNESJacobianEval',
    'PCSetUp',
    'MatMult',
]

def normalize_hostname(hostname):
    """Map a raw node hostname onto the name of the machine it belongs to.

    The hostname is tested against each known naming pattern in turn; the
    patterns are anchored at the start of the string only, and the first
    one that matches decides the result.

    Args:
        hostname: node name as it appears in the log.

    Returns:
        One of 'Crusher', 'Summit', 'Lassen' or 'Perlmutter'.

    Raises:
        RuntimeError: if the hostname fits none of the known patterns.
    """
    # Order matters: patterns are tried top to bottom.
    machine_patterns = (
        (r'crusher\d+', 'Crusher'),
        (r'[a-z]\d+n\d+', 'Summit'),
        (r'lassen\d+', 'Lassen'),
        (r'nid\d+', 'Perlmutter'),
    )
    for pattern, machine in machine_patterns:
        if re.match(pattern, hostname):
            return machine
    raise RuntimeError(f"Cannot normalize hostname: {hostname}")

def gpus_per_node(hostname):
    """Return the number of GPUs per node for a normalized machine name.

    Args:
        hostname: machine name as produced by normalize_hostname().

    Returns:
        The GPU count for a known machine, or 0 for any other name.
    """
    known_machines = {
        'Crusher': 8,
        'Summit': 6,
        'Lassen': 4,
        'Perlmutter': 4,
    }
    if hostname in known_machines:
        return known_machines[hostname]
    return 0

def num_nodes(mpi, gpn, filename):
    """Derive the number of nodes used by a run from its MPI rank count.

    Runs are assumed to use one MPI rank per GPU, except for the log files
    listed below (matched by base name, directory ignored), which used two.

    Args:
        mpi: total number of MPI ranks of the run.
        gpn: number of GPUs per node.
        filename: path of the log file the run was read from.

    Returns:
        ``mpi`` floor-divided by the number of ranks that fit on one node.
    """
    two_ranks_per_gpu = frozenset((
        # Crusher
        'schwarz-q2-t20-r2-l2-89099.out',
        'schwarz-q2-t20-r2-l2-89101.out',
        'schwarz-q2-t20-r2-l2-94011.out',
        # Summit
        'schwarz-q2-t20-r2-l2-1952766.out',
        'schwarz-q2-t20-r2-l2-1952897.out',
        'schwarz-q2-t20-r2-l2-1952767.out',
        # Lassen
        'schwarz-q2-t20-r2-l2-3394323.out',
        'schwarz-q2-t20-r2-l2-3394325.out',
        'schwarz-q2-t20-r2-l2-3394384.out',
    ))
    log_name = os.path.basename(filename)
    ranks_per_gpu = 2 if log_name in two_ranks_per_gpu else 1
    return mpi // (ranks_per_gpu * gpn)
def get_displacements(ll):
    """Convert the tokens of a bracketed, comma-separated list into floats.

    Args:
        ll: iterable of string tokens, e.g. ``['[1.0,', '2.0,', '3.0]']``.

    Returns:
        A list with one float per token, in the same order.
    """
    def to_float(token):
        # Remove surrounding whitespace, then the list punctuation.
        cleaned = token.strip()
        for punctuation in '[],':
            cleaned = cleaned.replace(punctuation, '')
        return float(cleaned)

    return [to_float(token) for token in ll]

def parse_file_content(filename):
    """Parse one run log into a list of per-run records.

    A record is started by a line beginning with ``Ratel Context:`` and is
    completed (appended to the result) by a line beginning with
    ``#End of PETSc Option Table entries``.  In between, recognised lines
    fill in the record's fields: 'Hostname', 'Order', 'MatType', 'Extent',
    'Displacement', 'SNES iterations' (summed over all 'Nonlinear solve'
    lines), 'MPI', 'Nodes', 'GPU', 'Global DoFs', 'Computed strain energy',
    and one timing entry per name listed in EVENTS.

    Fields seen outside a ``Ratel Context:`` section are collected into a
    fresh record instead of raising, and a file without any recognised
    line simply yields an empty list.

    Args:
        filename: path of the log file to read.

    Returns:
        A list of dicts, one per completed record.  A trailing record that
        was never completed is reported on stdout and discarded.

    Raises:
        RuntimeError: if a 'Total ranks' line is met before any 'Hostname'
            line, or if normalize_hostname() does not know the hostname.
    """
    record_list = []
    # Initialise the parsing state up front so that files which are empty or
    # do not start with 'Ratel Context:' are handled instead of failing on an
    # unbound local name.
    file_data = {}
    # The most recent hostname is kept across records, so a record that has
    # no 'Hostname' line of its own reuses the previous one.
    hostname = None
    with open(filename, 'r') as fd:
        for line in fd:
            stripped = line.strip()
            ll = stripped.split()
            if line.startswith('Ratel Context:'):
                # Start of a new run: begin a fresh record.
                file_data = {'SNES iterations': 0}
            elif stripped.startswith('Hostname'):
                hostname = normalize_hostname(ll[-1])
                file_data['Hostname'] = hostname
            elif stripped.startswith('Polynomial order'):
                file_data['Order'] = int(ll[-1])
            elif stripped.startswith('MatType'):
                file_data['MatType'] = ll[-1]
            elif stripped.startswith('-dm_plex_tps_extent'):
                # Only the first comma-separated component is kept, as text.
                file_data['Extent'] = ll[1].split(',')[0]
            elif stripped.startswith('Max displacements'):
                file_data['Displacement'] = get_displacements(ll[2:])
            elif stripped.startswith('Nonlinear solve'):
                # Accumulate over every nonlinear solve of the run.
                file_data['SNES iterations'] = file_data.get('SNES iterations', 0) + int(ll[-1])
            elif stripped.startswith('Total ranks'):
                if hostname is None:
                    raise RuntimeError(
                        f"'Total ranks' found before any 'Hostname' in {filename}")
                mpi = int(ll[-1])
                gpn = gpus_per_node(hostname)
                nodes = num_nodes(mpi, gpn, filename)
                file_data['MPI'] = mpi
                file_data['Nodes'] = nodes
                file_data['GPU'] = nodes * gpn
            elif 'Global DoFs' in line:
                file_data['Global DoFs'] = int(ll[-1])
            elif 'Computed strain energy' in line:
                file_data['Computed strain energy'] = float(ll[-1])
            elif ll and ll[0] in EVENTS:
                # Event rows are read by fixed character columns of the raw
                # line rather than by whitespace splitting:
                #   [0:16] event, [17:24] count, [24:28] count balance,
                #   [29:39] time, [39:43] time balance.
                # Only the time column is recorded.
                file_data[ll[0]] = float(line[29:39])
            if line.startswith("#End of PETSc Option Table entries"):
                # End of the run's output: the record is complete.
                record_list.append(file_data)
                file_data = {}

    if file_data:
        print(f"Incomplete records in {filename}; discarding incompletes: {file_data}")
    return record_list

def create_data_frame(files_data):
    """Build a DataFrame from parsed records and set pandas display options.

    As a side effect the global pandas options are changed so that frames
    are printed without wrapping and floats are shown with 12 decimals.

    Args:
        files_data: list of record dicts, one per row.

    Returns:
        The DataFrame built from ``files_data``.
    """
    def twelve_decimals(x):
        return '%.12f' % x

    frame = pd.DataFrame.from_records(files_data)
    for option, value in (
        ('display.expand_frame_repr', False),
        ('display.float_format', twelve_decimals),
    ):
        pd.set_option(option, value)
    return frame

def run_alg_perf(filenames):
    """Parse every given log file and gather all records in one DataFrame.

    Args:
        filenames: iterable of log file paths.

    Returns:
        The DataFrame built by create_data_frame() from the records of all
        files, in the order the files were given.
    """
    # Flatten the per-file record lists into a single list.
    records = [
        record
        for path in filenames
        for record in parse_file_content(path)
    ]
    return create_data_frame(records)

def run_plot(df, events, plot, output):
    """Plot per-GPU efficiency against event time, one figure per event.

    For each event a line plot of 'Efficiency (MDoF/s/GPU)' (global DoFs
    per second of event time per GPU, in millions) over the event time is
    drawn, coloured by 'Hostname', styled by 'Nodes' and sized by
    'MPI per GPU'.

    Note:
        ``df`` is modified in place: the columns 'MPI per GPU' and
        'Efficiency (MDoF/s/GPU)' are added (the latter holds the values of
        the last event plotted), and rows with Hostname 'Lassen' and more
        than one node are removed.

    Args:
        df: DataFrame with at least the columns 'MPI', 'GPU', 'Hostname',
            'Nodes', 'Global DoFs' and one column per entry of ``events``.
        events: names of the event columns to plot.
        plot: if true, show each figure interactively.
        output: if truthy, prefix of the files each figure is saved to:
            ``{output}-{event}.svg`` and ``../figures/{output}-{event}.pdf``.
    """
    df['MPI per GPU'] = df['MPI'] // df['GPU']

    # Use one combined mask: chaining df.loc[a][b] applies a mask built on the
    # full frame to an already-filtered frame, which makes pandas warn that the
    # boolean key is being reindexed.  The filter does not depend on the
    # event, so it is applied once up front instead of on every iteration.
    lassen_multi_node = (df['Hostname'] == 'Lassen') & (df['Nodes'] > 1)
    df.drop(df[lassen_multi_node].index, inplace=True)

    for event in events:
        df['Efficiency (MDoF/s/GPU)'] = df['Global DoFs'] / (df[event] * df['GPU']) * 1e-6
        fig, ax = plt.subplots(figsize=(10, 6), layout='tight')
        sns.lineplot(
            x=event,
            y='Efficiency (MDoF/s/GPU)',
            style='Nodes',
            size='MPI per GPU',
            palette='colorblind',
            hue='Hostname',
            markers=True,
            sizes=(2, 4),
            alpha=.7,
            ax=ax,
            data=df,
        )
        ax.set_xscale('log')
        ax.set_ylim(bottom=0)
        ax.legend(loc='lower right')
        if output:
            plt.savefig(f"{output}-{event}.svg")
            plt.savefig(f"../figures/{output}-{event}.pdf")
        if plot:
            plt.show()

def run_plot_apply(df, plot, output):
    """Plot operator-apply efficiency per GPU against problem size.

    Draws 'Efficiency (GDoF/s/GPU)' over 'Global DoFs' on a logarithmic
    x axis, coloured by 'MatType', styled by 'Order' and sized by 'Nodes'.

    Note:
        ``df`` is modified in place: the columns 'Efficiency (GDoF/s/GPU)'
        and 'MatMult (ms)' are added.

    Args:
        df: DataFrame with the columns 'Global DoFs', 'MatMult', 'GPU',
            'Order', 'MatType' and 'Nodes'.
        plot: if true, show the figure interactively.
        output: if truthy, the figure is saved to ``{output}-apply.svg``.
    """
    matmult_count = 1500  # 1500 matrix multiplications

    fig, ax = plt.subplots(figsize=(10, 6), layout='tight')

    gpu_time = df['MatMult'] * df['GPU']
    df['Efficiency (GDoF/s/GPU)'] = df['Global DoFs'] / gpu_time * matmult_count * 1e-9
    # Time of a single multiplication in milliseconds.
    df['MatMult (ms)'] = df['MatMult'] * 1000 / matmult_count

    lineplot_options = dict(
        # Alternative x axis: 'MatMult (ms)'
        x='Global DoFs',
        y='Efficiency (GDoF/s/GPU)',
        hue='MatType',
        palette='bright',
        style='Order',
        markers=True,
        size="Nodes",
        sizes=(2, 4),
        alpha=.7,
        data=df,
        ax=ax,
    )
    sns.lineplot(**lineplot_options)

    ax.set_xscale('log')
    ax.set_ylim(bottom=0)
    plt.legend(loc='lower right')

    if output:
        plt.savefig(f"{output}-apply.svg")
    if plot:
        plt.show()


if __name__ == "__main__":
    # Command-line entry point: parse the given log files, print the
    # resulting table, optionally save it as CSV, and produce the requested
    # plots.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--plot", help="Plot results", action="store_true")
    parser.add_argument("--plot-apply", help="Plot operator apply", action="store_true")
    parser.add_argument('--event', help="PETSc event to plot", default=None)
    parser.add_argument('--save_data', help="Where to save csv to", default=None)
    parser.add_argument("--output", help='Output file for plotting')
    parser.add_argument("--filenames", help="List of files", nargs="*")
    args = parser.parse_args()
    # Without --filenames argparse leaves the value as None, which would make
    # run_alg_perf() fail with a TypeError; report a usage error instead.
    if args.filenames is None:
        parser.error("no input files given; pass them with --filenames")
    df = run_alg_perf(args.filenames)
    print(df)
    # Honour --save_data only when it was actually given (its default is None).
    if args.save_data:
        df.to_csv(args.save_data + ".csv", sep=',')
    if args.plot or args.output:
        # --event takes a comma-separated list; default to all known events.
        events = args.event.split(',') if args.event else EVENTS
        run_plot(df, events, args.plot, args.output)
    if args.plot_apply:
        run_plot_apply(df, args.plot, args.output)
