from Classic import *
import argparse
def cal_N50(df, node_numbers, N_ratio):
    """Return the N-statistic (N50, N20, ...) of the counts in ``df['number']``.

    Rows are visited from the largest ``'number'`` to the smallest while a
    running total is kept; the count of the row at which the running total
    first reaches ``node_numbers * N_ratio`` is returned.

    Args:
        df: DataFrame with a ``'number'`` column of per-row counts (the
            script in this file passes connected-component sizes).
        node_numbers: Total the ratio is taken against.
        N_ratio: Fraction of ``node_numbers`` that must be covered,
            e.g. ``0.5`` for N50 or ``0.2`` for N20.

    Returns:
        The count of the row that reaches the target, or ``1`` when the rows
        together never reach it (including when ``df`` has no rows).
    """
    # No rows at all: nothing can cover the target, so use the same fallback
    # as the "never reached" case instead of failing on an index lookup.
    if df.empty:
        return 1

    # Read the counts by column name so the result does not depend on the
    # position of 'number' inside the frame.
    sizes = df.sort_values('number', ascending=False)['number'].to_numpy()

    # NOTE(review): a single-row frame returns its count without checking the
    # target; this shortcut is kept from the original implementation.
    if len(sizes) == 1:
        return sizes[0]

    target = node_numbers * N_ratio
    covered = 0
    for size in sizes:
        covered += size
        # Test after adding, so the last row is checked as well.
        if covered >= target:
            return size
    return 1

def convert_extension(input_string, threshold, new_extension):
    """Build an output file name from an ``.mgf`` file name.

    The threshold is appended to the base name (separated by ``_``) and the
    ``mgf`` extension is replaced by ``new_extension``.

    Args:
        input_string: Input file name; its extension must be ``mgf``
            (compared case-insensitively).
        threshold: Value appended to the base name via ``str(threshold)``.
        new_extension: Extension of the returned name, without the dot.

    Returns:
        ``"<name>_<threshold>.<new_extension>"`` for an ``.mgf`` input,
        otherwise the string ``"Invalid input file format"``.
    """
    # rpartition always yields three parts, so a name without any '.' is
    # reported as invalid instead of failing during tuple unpacking.
    name, dot, old_extension = input_string.rpartition('.')
    if not dot or old_extension.lower() != 'mgf':
        return "Invalid input file format"

    return f"{name}_{threshold}.{new_extension}"
# def cal_N50(df, node_numbers,N_ratio):
#     dfnew=df.sort_values('number',ascending=False)
#     number=0
#     row_old = dfnew.values[0][1]
#     for row in dfnew.values:
#         if (number >= node_numbers*N_ratio):
#             return row_old
#         else:
#             number=number+ row[1]
#             row_old = row[1]
#     return 1
if __name__ == '__main__':
    # Command-line arguments: the library to evaluate and the realignment
    # method whose thresholded networks are scored.
    parser = argparse.ArgumentParser(description='Using realignment method to reconstruct the network')
    parser.add_argument('--input', type=str,required=True, help='input libray name')
    # Not marked as required: argparse ignores `default` on a required
    # option, so leaving it optional is what makes the default take effect.
    parser.add_argument('--method', type=str, default="MS2DeepScore", help='realignment method')
    args = parser.parse_args()
    input_lib_name = args.input
    benchmark_method = args.method

    # NOTE(review): the path ends in ".tsv" but is parsed with read_csv's
    # default separator (comma) -- confirm the delimiter of this file.
    summary_file_path = "../../data/summary/"+input_lib_name + "_summary.tsv"
    cluster_summary_df = pd.read_csv(summary_file_path)

    # Prefix of the per-threshold pair files; "_<threshold>.tsv" is appended
    # inside the loop below.
    if benchmark_method == "MS2DeepScore":
        merged_file = "../../results/MS2DeepScore/"+input_lib_name
    else:
        # NOTE(review): there is no path separator between "Network_barebone"
        # and the library name -- confirm this is the intended file prefix.
        merged_file = "../../data/Network_barebone"+input_lib_name
    print(merged_file)

    threshold_list = [0.5,0.6,0.7,0.8,0.9,0.91,0.92,0.93,0.94,0.95,0.96,0.97,0.98,0.99,0.991,0.992,0.993,0.994,0.995,0.996,0.997]
    N20_list=[]
    score_list=[]

    # The graph built from the unthresholded merged pairs supplies the node
    # total that every per-threshold N20 is measured against.
    merged_file_original = "../../data/merged_pairs/"+input_lib_name + "_merged_pairs.tsv"
    original_all_pairs_df = pd.read_csv(merged_file_original, sep='\t')
    G_original = nx.from_pandas_edgelist(original_all_pairs_df, "CLUSTERID1", "CLUSTERID2", "Cosine")
    num_of_nodes = G_original.number_of_nodes()
    print(num_of_nodes)

    for threshold in threshold_list:
        merged_pairs_file_path = merged_file + "_" + str(threshold) + ".tsv"
        print(merged_pairs_file_path)
        all_pairs_df = pd.read_csv(merged_pairs_file_path, sep='\t')
        G_all_pairs = nx.from_pandas_edgelist(all_pairs_df, "CLUSTERID1", "CLUSTERID2", "Cosine")
        print('graph with {} nodes and {} edges'.format(G_all_pairs.number_of_nodes(), G_all_pairs.number_of_edges()))

        # NOTE(review): rebuilt on every iteration from the same summary
        # frame; it could probably be hoisted if the helpers do not mutate
        # their inputs -- confirm before moving it.
        print("constructing dic for finger print")
        dic_fp = fingerprint_dic_construct(cluster_summary_df)

        # One score and one node count per connected component.
        components = [G_all_pairs.subgraph(c).copy() for c in nx.connected_components(G_all_pairs)]
        component_scores = []
        for component in tqdm(components):
            component_scores.append(subgraph_score_dic(component,cluster_summary_df,dic_fp))
        component_sizes = [len(x) for x in components]
        df_all_pairs_filter = pd.DataFrame(list(zip(component_scores, component_sizes)),columns=['score', 'number'])

        # Compute each metric once; the one-element arrays keep the printed
        # output and the stored scalars in the same form as before.
        n20_values = np.array([cal_N50(df_all_pairs_filter, num_of_nodes, 0.2)])
        score_values = np.array([weighted_average(df_all_pairs_filter, 'score', 'number')])
        print(n20_values)
        print(score_values)
        N20_list.append(n20_values[0])
        score_list.append(score_values[0])

    print("N20 list:",N20_list)
    print("Network Accuracy Score list:",score_list)




