import os
import re
import json

# Length buckets as (low, high) pairs. The statistics functions in this file
# test membership with ``low <= length <= high``, so both bounds are inclusive;
# the last bucket is open-ended.
intervals = [
    (0, 100),
    (101, 200),
    (201, 300),
    (301, 400),
    (401, 500),
    (501, float('inf')),
]

def compare_files(file1, file2):
    """Return the fraction of line positions at which two text files agree.

    Lines are paired by position and compared after stripping surrounding
    whitespace. This is the same equality rule that ``compare_and_count_cwe``
    and ``length_statistic`` apply, and it prevents a missing newline at the
    end of one file from turning an identical last line into a mismatch.

    Args:
        file1: Path of the first UTF-8 text file.
        file2: Path of the second UTF-8 text file.

    Returns:
        The number of matching line pairs divided by the line count of the
        longer file (unpaired lines therefore count as mismatches), or
        ``0.0`` when both files are empty.

    Side effects:
        Prints the matching-line count and the total line count whenever at
        least one file is non-empty.
    """
    with open(file1, 'r', encoding='utf-8') as f1, open(file2, 'r', encoding='utf-8') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()

    total_lines = max(len(lines1), len(lines2))
    if total_lines == 0:
        return 0.0

    same_lines_count = sum(
        1 for line1, line2 in zip(lines1, lines2) if line1.strip() == line2.strip()
    )
    print(same_lines_count)
    print(total_lines)
    return same_lines_count / total_lines
    
def extract_cwe_mapping(jsonl_file):
    """Map 1-based line numbers of a JSONL file to the CWE id leading ``source``.

    Each non-blank line is parsed as a JSON object. When its ``source`` value
    starts with ``CWE-<digits>`` followed by whitespace, that id (without the
    trailing whitespace) is recorded under the line's number.

    Args:
        jsonl_file: Path of a UTF-8 file holding one JSON object per line.

    Returns:
        A dict ``{line_number: "CWE-<digits>"}`` containing only the lines
        whose ``source`` carries such a prefix.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # CWE ids are not limited to three digits (e.g. CWE-1021), so accept any
    # run of digits. Compiled once because it is applied to every record.
    cwe_prefix = re.compile(r'(CWE-\d+)\s')

    cwe_mapping = {}
    with open(jsonl_file, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            if not line.strip():
                # A blank line (typically a trailing one) holds no record.
                # It still consumes a line number, so the numbering keeps
                # following the physical lines of the file.
                continue
            data = json.loads(line)
            source = data.get('source', '')
            cwe_match = cwe_prefix.match(source)
            if cwe_match:
                cwe_mapping[line_number] = cwe_match.group(1)
    return cwe_mapping

def count_cwe_occurrences(cwe_mapping, interested_cwes):
    """Count how often each CWE of interest appears among the mapping's values.

    Args:
        cwe_mapping: Dict whose values are CWE ids (keys are ignored).
        interested_cwes: Collection of CWE ids to tally.

    Returns:
        A dict with one entry per id in ``interested_cwes`` (zero when the id
        never occurs); ids outside that collection are not counted.
    """
    tally = dict.fromkeys(interested_cwes, 0)
    relevant = (label for label in cwe_mapping.values() if label in interested_cwes)
    for label in relevant:
        tally[label] += 1
    return tally

def compare_and_count_cwe(file1, file2, mapping, interested_cwes):
    """Count, per CWE of interest, the line positions where two files agree.

    The files are read in lockstep (stopping at the end of the shorter one).
    Whenever the whitespace-stripped lines at a position are equal, the CWE
    that ``mapping`` assigns to that 1-based position is looked up and, if it
    is one of ``interested_cwes``, its counter is incremented.

    Args:
        file1: Path of the first UTF-8 text file.
        file2: Path of the second UTF-8 text file.
        mapping: Dict from 1-based line number to CWE id.
        interested_cwes: Collection of CWE ids to tally.

    Returns:
        A dict with one counter per id in ``interested_cwes``.
    """
    matches_per_cwe = dict.fromkeys(interested_cwes, 0)

    with open(file1, 'r', encoding='utf-8') as left, open(file2, 'r', encoding='utf-8') as right:
        for position, pair in enumerate(zip(left, right), start=1):
            first, second = (text.strip() for text in pair)
            if first != second:
                continue

            label = mapping.get(position, '')
            if label in interested_cwes:
                matches_per_cwe[label] += 1

    return matches_per_cwe

def calculate_cwe_percentage(identical_cwe_counts, total_cwe_counts):
    """Express each CWE's identical count as a percentage of its total count.

    Args:
        identical_cwe_counts: Dict from CWE id to the number of matches.
        total_cwe_counts: Dict from CWE id to the number of occurrences.

    Returns:
        A dict from CWE id to ``matches / total * 100``. Ids whose total is
        missing or zero are left out.
    """
    result = {}
    for cwe_id, matched in identical_cwe_counts.items():
        denominator = total_cwe_counts.get(cwe_id, 0)
        if denominator == 0:
            # No occurrences recorded: a percentage would be undefined.
            continue
        result[cwe_id] = (matched / denominator) * 100
    return result

def length_statistic():
    """Print the exact-match rate for each source-length bucket.

    Reads three files in lockstep, taking their paths from the module-level
    names ``gold_file``, ``output_file`` and ``test_filename`` (these, and
    ``tokenizer``, must be defined before the call; they are not parameters).
    For every position:

    * the ``source`` field of the JSON record in ``test_filename`` is encoded
      with ``tokenizer.encode(..., return_tensors="pt")`` and the second
      dimension of the result is taken as its length;
    * the record is counted in the bucket of ``intervals`` containing that
      length, and counted again as an exact match when the stripped lines of
      ``gold_file`` and ``output_file`` are equal.

    One line per bucket is then printed with the ``matched / total`` ratio and
    the corresponding percentage (``0 / 0`` and ``0.00%`` for empty buckets).

    Returns:
        None. The result is only printed.
    """
    # Plain dicts pre-seeded with every bucket: only keys taken from
    # ``intervals`` are ever touched, so no ``defaultdict`` is needed.
    length_intervals_em = dict.fromkeys(intervals, 0)
    length_intervals_all = dict.fromkeys(intervals, 0)

    with open(gold_file, "r", encoding="utf-8") as file_A, \
        open(output_file, "r", encoding="utf-8") as file_B, \
        open(test_filename, "r", encoding="utf-8") as file_input:
        for line_A, line_B, line_input in zip(file_A, file_B, file_input):
            is_exact_match = line_A.strip() == line_B.strip()

            json_data = json.loads(line_input)
            source_text = json_data["source"]
            tokenized_source_all = tokenizer.encode(source_text, return_tensors="pt")
            source_length_input = tokenized_source_all.size(1)

            # One pass over the buckets updates both tallies.
            for interval in intervals:
                if interval[0] <= source_length_input <= interval[1]:
                    length_intervals_all[interval] += 1
                    if is_exact_match:
                        length_intervals_em[interval] += 1

    for interval in intervals:
        matched = length_intervals_em[interval]
        total = length_intervals_all[interval]
        # An empty bucket has matched == total == 0, which renders as "0 / 0".
        ratio_string = f"{matched} / {total}"
        ratio_float = (matched / total if total != 0 else 0.0) * 100
        print(f"{interval} 比例数值: {ratio_string}  百分比: {ratio_float:.2f}%")
               
       
def input_length_statistic():
    """Print length histograms of source texts and of retrieved texts.

    Paths come from the module-level names ``train_filename`` and
    ``retrieval_filename``; ``tokenizer`` must also be defined at module level
    before the call. Each line of either file is parsed as a JSON object:

    * from ``train_filename`` the ``source`` field is measured;
    * from ``retrieval_filename`` the ``retrived`` field is measured, and
      records whose field is empty are skipped.

    A text's length is the second dimension of
    ``tokenizer.encode(text, return_tensors="pt")``. Counts are accumulated
    per bucket of ``intervals`` and printed under the headings ``source:``
    and ``retrieved:``.

    Returns:
        None. The result is only printed.
    """
    # Plain dicts pre-seeded with every bucket: only keys taken from
    # ``intervals`` are ever touched, so no ``defaultdict`` is needed.
    length_intervals_src = dict.fromkeys(intervals, 0)
    length_intervals_ret = dict.fromkeys(intervals, 0)

    with open(train_filename, "r", encoding="utf-8") as file:
        for line in file:
            json_data = json.loads(line)
            source_text = json_data["source"]
            tokenized_line_src = tokenizer.encode(source_text, return_tensors="pt")
            line_length_src = tokenized_line_src.size(1)

            for interval in intervals:
                if interval[0] <= line_length_src <= interval[1]:
                    length_intervals_src[interval] += 1

    with open(retrieval_filename, "r", encoding="utf-8") as file:
        for line in file:
            json_data = json.loads(line)
            # Key spelling kept as-is; presumably it matches the data files.
            retrieval_text = json_data["retrived"]
            if not retrieval_text:
                # Nothing to measure. Falling through here would either reuse
                # the previous record's token tensor or fail on an unbound
                # variable when the very first record is empty.
                continue
            tokenized_line_ret = tokenizer.encode(retrieval_text, return_tensors="pt")
            line_length_ret = tokenized_line_ret.size(1)

            for interval in intervals:
                if interval[0] <= line_length_ret <= interval[1]:
                    length_intervals_ret[interval] += 1

    sections = (
        ("source:", length_intervals_src),
        ("retrieved:", length_intervals_ret),
    )
    for heading, counts in sections:
        print(heading)
        for interval in intervals:
            count = counts[interval]
            if interval[1] == float('inf'):
                print(f"> {interval[0]}: {count}")
            else:
                print(f"{interval[0]}-{interval[1]}: {count}")

