import os
import re
import json


def read_jsonl_file(file_path):
    """Load a JSON Lines file into a list of decoded values.

    Args:
        file_path: Path of a UTF-8 text file holding one JSON document
            per line.

    Returns:
        A list with one decoded value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Whitespace-only lines (typically a stray newline at the end
            # of the file) hold no record; json.loads would raise on them.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records

def get_repair_result(data):
    """Flag, per entry, whether its reference output is among its beams.

    Args:
        data: Iterable of mappings, each providing an 'output' value and a
            'beam_output' iterable of candidates.

    Returns:
        A list of ints aligned with ``data``: 1 where at least one
        candidate compares equal to the entry's 'output', otherwise 0.
    """
    hits = []
    for record in data:
        expected = record['output']
        candidates = record['beam_output']
        # any() stops at the first equal candidate, like the early break.
        matched = any(candidate == expected for candidate in candidates)
        hits.append(1 if matched else 0)
    return hits

def get_v_repair_result(data,trap_data):
    """Flag, per paired entry, whether a beam equals the reference answer.

    Entries of ``trap_data`` and ``data`` are paired positionally; pairing
    stops at the end of the shorter input.

    Args:
        data: Iterable of mappings providing a 'beam_output' iterable.
        trap_data: Iterable of mappings providing the reference 'output'.

    Returns:
        A list of ints, one per pair: 1 where some beam candidate compares
        equal to the paired reference 'output', otherwise 0.
    """
    hits = []
    for reference, prediction in zip(trap_data, data):
        expected = reference['output']
        candidates = prediction['beam_output']
        matched = any(candidate == expected for candidate in candidates)
        hits.append(1 if matched else 0)
    return hits
        
def extract_cwe_mapping(jsonl_file):
    """Map 1-based line numbers of a JSONL file to their leading CWE label.

    Each line is decoded as a JSON object and its 'source' field (empty
    string when absent) is tested for a ``CWE-<digits>`` prefix followed
    by a whitespace character. Lines without that prefix are left out.

    NOTE(review): the pattern accepts one to three digits only, so a
    four-digit id such as ``CWE-1021`` is not matched - confirm that this
    is intended.

    Args:
        jsonl_file: Path of a UTF-8 JSON Lines file.

    Returns:
        A dict from line number (starting at 1) to the label, e.g.
        ``'CWE-119'``, with the trailing whitespace removed.
    """
    prefix_pattern = re.compile(r'^CWE-\d{1,3}\s')
    labels_by_line = {}
    with open(jsonl_file, 'r', encoding='utf-8') as handle:
        for line_number, raw in enumerate(handle, start=1):
            record = json.loads(raw)
            found = prefix_pattern.match(record.get('source', ''))
            if found is None:
                continue
            labels_by_line[line_number] = found.group(0).strip()
    return labels_by_line

def count_cwe_occurrences(cwe_mapping,interested_cwes):
    """Tally how often each CWE of interest occurs among the mapping values.

    Args:
        cwe_mapping: Dict whose values are CWE labels.
        interested_cwes: Collection of CWE labels to count.

    Returns:
        A dict with one key per label in ``interested_cwes`` (zero when
        the label never occurs) holding its number of occurrences.
    """
    tally = dict.fromkeys(interested_cwes, 0)
    for label in cwe_mapping.values():
        if label not in interested_cwes:
            continue
        tally[label] += 1
    return tally

def cwe_compare(trap_file,mapping,interested_cwes):
    """Count the records in a JSONL file that carry each CWE of interest.

    A record's CWE is the ``CWE-<1 to 3 digits>`` prefix of its 'source'
    field, which must be followed by a whitespace character; records
    without such a prefix are ignored.

    Args:
        trap_file: Path of a UTF-8 JSON Lines file.
        mapping: Not used by this function; kept so existing calls work.
        interested_cwes: Iterable of CWE labels to count.

    Returns:
        A dict with one key per label in ``interested_cwes`` holding the
        number of records whose 'source' starts with that label. (The
        tally used to be computed and then discarded, so the function
        always returned None.)

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    cwe_prefix = re.compile(r'^CWE-\d{1,3}\s')
    cwe_counts = {cwe: 0 for cwe in interested_cwes}
    with open(trap_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines hold no record and would make json.loads raise.
            if not line.strip():
                continue
            data = json.loads(line)
            cwe_match = cwe_prefix.match(data.get('source', ''))
            if cwe_match is None:
                continue
            cwe = cwe_match.group(0).strip()
            if cwe in cwe_counts:
                cwe_counts[cwe] += 1
    return cwe_counts
                 
def compare_cwe_count(data,mapping,interested_cwes):
    """Count, per CWE of interest, the entries whose beams contain the answer.

    An entry counts when some item of its 'beam_output' compares equal to
    its 'output'. The CWE of a counted entry is looked up in ``mapping``
    by the entry's 0-based position in ``data``.

    NOTE(review): positions here start at 0, whereas extract_cwe_mapping
    in this file keys its result by 1-based line number - confirm which
    numbering callers pass in ``mapping``.

    Args:
        data: Iterable of mappings with 'output' and 'beam_output'.
        mapping: Dict from entry position to CWE label.
        interested_cwes: Collection of CWE labels to count.

    Returns:
        A dict with one key per label in ``interested_cwes`` holding the
        number of counted entries mapped to that label.
    """
    tally = dict.fromkeys(interested_cwes, 0)

    # First pass: decide for every entry whether any beam equals the answer.
    matched = []
    for record in data:
        expected = record['output']
        candidates = record['beam_output']
        matched.append(any(candidate == expected for candidate in candidates))

    # Second pass: attribute each matching entry to its CWE, if of interest.
    for position, hit in enumerate(matched):
        if not hit:
            continue
        label = mapping.get(position, '')
        if label in interested_cwes:
            tally[label] += 1

    return tally

def calculate_cwe_percentage(identical_cwe_counts, total_cwe_counts):
    """Compute per-CWE hit percentages and print the overall percentage.

    CWEs whose total is missing or zero are skipped entirely: they appear
    in neither returned dict and do not contribute to the overall figure.

    Args:
        identical_cwe_counts: Dict from CWE label to its number of hits.
        total_cwe_counts: Dict from CWE label to its total occurrences.

    Returns:
        A ``(cwe_percentages, proportion)`` tuple where the first dict maps
        each CWE to ``hits / total * 100`` and the second maps it to the
        string ``"hits/total"``.
    """
    cwe_percentages = {}
    proportion = {}
    all_count = 0
    all_total_count = 0
    for cwe, count in identical_cwe_counts.items():
        total_count = total_cwe_counts.get(cwe, 0)
        if total_count == 0:
            continue
        cwe_percentages[cwe] = (count / total_count) * 100
        proportion[cwe] = f"{count}/{total_count}"
        all_count += count
        all_total_count += total_count
    # With no usable totals (empty input or all totals zero) the division
    # would raise ZeroDivisionError; report 0% instead.
    if all_total_count:
        total_percentage = (all_count / all_total_count) * 100
    else:
        total_percentage = 0.0
    print(f"Total intrest: {all_count}/{all_total_count}   {total_percentage:.2f}%")
    return cwe_percentages,proportion


def get_unique_cwe(result,cwe_dict_test,type,result_file):
    """Write the distinct CWE labels of all successful entries to a file.

    An entry is successful when its value in ``result`` equals 1; its label
    is looked up in ``cwe_dict_test`` by 0-based position. A successful
    position with no label contributes the empty string, which is written
    as a blank line.

    Labels are written one per line in sorted order so that repeated runs
    produce identical files (set iteration order of strings is not stable
    across interpreter runs).

    Args:
        result: Iterable of 0/1 flags, one per entry.
        cwe_dict_test: Dict from entry position to CWE label.
        type: Not used by this function; kept so existing calls work.
        result_file: Path of the UTF-8 text file to (over)write.

    Returns:
        None.
    """
    unique_cwes = {
        cwe_dict_test.get(idx, '')
        for idx, item in enumerate(result)
        if item == 1
    }
    with open(result_file, 'w', encoding='utf-8') as f:
        f.writelines(cwe + '\n' for cwe in sorted(unique_cwes))



