import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from transformers import T5Tokenizer, T5EncoderModel
from transformers import RobertaTokenizer, T5ForConditionalGeneration
import torch
from torch.utils.data import Dataset, DataLoader
import json
import os

# --------------- Data loading ---------------
# The datasets arrive pre-split into train/test CSVs, each with a "code"
# column (source snippet) and a "label" column (binary vulnerability label).
# NOTE(review): an earlier single-CSV + train_test_split workflow was removed
# in favor of these fixed splits.

train_csv = "/home/longyuanjun/workspaces/github-code-clean/cwe_20_train_dataset_26489_14126_rl.csv"
test_csv = "/home/longyuanjun/workspaces/github-code-clean/cwe_20_test_dataset_6623_3531_rl.csv"

df_train = pd.read_csv(train_csv)
codes_train = df_train["code"].tolist()
labels_train = df_train["label"].tolist()

df_test = pd.read_csv(test_csv)
codes_test = df_test["code"].tolist()
labels_test = df_test["label"].tolist()

# --------------- CodeT5 tokenization ---------------
# CodeT5 ships a RoBERTa-style BPE tokenizer; loaded from a local checkpoint.
tokenizer = RobertaTokenizer.from_pretrained("/home/longyuanjun/models/codet5-base/")
max_length = 128  # truncation/padding length; tune to the snippet length distribution


def encode_texts(texts, prefix=""):
    """Tokenize a list of code strings into fixed-length tensor batches.

    Args:
        texts: list of code snippets to encode.
        prefix: optional task prefix prepended to every snippet (CodeT5
            seq2seq tasks often use one; this classification setup does not,
            so the default is empty to preserve existing behavior).

    Returns:
        A BatchEncoding holding ``input_ids`` and ``attention_mask`` tensors,
        each of shape [len(texts), max_length].
    """
    if prefix:
        texts = [f"{prefix}{text}" for text in texts]
    return tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )


# The original duplicated the tokenizer call verbatim for each split;
# routed through the single helper above instead.
train_encodings = encode_texts(codes_train)
test_encodings = encode_texts(codes_test)

# --------------- 构建 PyTorch Dataset ---------------
class TextDataset(Dataset):
    """Wrap pre-tokenized encodings and integer labels for a DataLoader.

    Generalized to pass through *every* tensor in ``encodings`` (not just
    ``input_ids``/``attention_mask``), so extra tokenizer fields such as
    ``token_type_ids`` survive batching unchanged. Backward compatible:
    items still contain ``input_ids``, ``attention_mask`` and ``label``.
    """

    def __init__(self, encodings, labels):
        # encodings: mapping of field name -> tensor [num_samples, ...]
        self.encodings = encodings
        # Convert once up front instead of building a tensor on every
        # __getitem__ call.
        self.labels = torch.as_tensor(labels, dtype=torch.long)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors plus its ``label``."""
        item = {key: tensor[idx] for key, tensor in self.encodings.items()}
        item["label"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

# --------------- Datasets and DataLoaders ---------------
batch_size = 16

train_dataset = TextDataset(train_encodings, labels_train)
test_dataset = TextDataset(test_encodings, labels_test)

# Shuffle only the training split; evaluation order does not matter.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(test_dataset, batch_size=batch_size)

# =================== Model definition ========================
from transformers import T5EncoderModel
import torch.nn as nn

class CodeT5LSTMClassifier(nn.Module):
    """CodeT5 encoder + bidirectional LSTM + linear head for classification.

    Args:
        lstm_hidden_size: hidden size per LSTM direction.
        num_classes: number of output classes.
        model_path: directory of the pretrained CodeT5 encoder (generalized
            from the previously hard-coded path; default preserves behavior).
        freeze_encoder: when True (default, matching the original), CodeT5
            weights are frozen and only the LSTM + head are trained.
    """

    def __init__(
        self,
        lstm_hidden_size=256,
        num_classes=2,
        model_path="/home/longyuanjun/models/codet5-base/",
        freeze_encoder=True,
    ):
        super().__init__()
        # CodeT5 encoder (encoder-only view of the seq2seq checkpoint)
        self.codet5 = T5EncoderModel.from_pretrained(model_path)
        if freeze_encoder:
            for param in self.codet5.parameters():
                param.requires_grad = False

        # BiLSTM over the encoder's token representations
        self.lstm = nn.LSTM(
            input_size=self.codet5.config.hidden_size,  # 768 for codet5-base
            hidden_size=lstm_hidden_size,
            batch_first=True,
            bidirectional=True,
        )

        # Classification head: dropout + linear over both LSTM directions
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(lstm_hidden_size * 2, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw class logits of shape [batch_size, num_classes]."""
        outputs = self.codet5(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state  # [batch, seq_len, 768]

        lstm_output, _ = self.lstm(sequence_output)  # [batch, seq_len, 2*hidden]

        # BUG FIX: the original read lstm_output[:, -1, :], i.e. the state at
        # the LAST position of the padded sequence. With padding="max_length"
        # that state mostly reflects PAD tokens for any snippet shorter than
        # max_length. Instead, index each sequence's last *real* token using
        # the attention mask.
        lengths = attention_mask.sum(dim=1).clamp(min=1) - 1  # [batch]
        batch_index = torch.arange(lstm_output.size(0), device=lstm_output.device)
        last_output = lstm_output[batch_index, lengths]  # [batch, 2*hidden]

        x = self.dropout(last_output)
        return self.fc(x)

# ================ Training and evaluation =========================
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = CodeT5LSTMClassifier().to(device)
# Only the unfrozen parameters (LSTM + classifier head) need optimizing;
# handing frozen encoder params to Adam is harmless but wasteful.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 10  # was hard-coded both in range() and in the progress print

# Output directories (the metrics JSON dir is created separately below)
os.makedirs("saved_models", exist_ok=True)
os.makedirs("metrics", exist_ok=True)

# Per-epoch log; epoch entries hold both train and val metrics, appended to
# metrics["train"] (the "test" list is kept for file-format compatibility).
metrics = {"train": [], "test": []}


def _train_one_epoch(loader):
    """Run one optimization pass; return (mean_loss, preds, gold_labels)."""
    model.train()
    total_loss = 0.0
    preds, gold = [], []
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
        gold.extend(labels.cpu().tolist())
    return total_loss / len(loader), preds, gold


def _evaluate(loader):
    """Run inference without gradients; return (preds, gold_labels)."""
    model.eval()
    preds, gold = [], []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids, attention_mask)
            preds.extend(torch.argmax(outputs, dim=1).cpu().tolist())
            gold.extend(labels.cpu().tolist())
    return preds, gold


def _binary_scores(gold, preds):
    """Return (accuracy, precision, recall, f1) for binary labels.

    zero_division=0 keeps degenerate epochs (model predicts a single class)
    from emitting warnings/NaN-like behavior; the score is reported as 0.
    """
    return (
        accuracy_score(gold, preds),
        precision_score(gold, preds, average="binary", zero_division=0),
        recall_score(gold, preds, average="binary", zero_division=0),
        f1_score(gold, preds, average="binary", zero_division=0),
    )


for epoch in range(NUM_EPOCHS):
    # --------------- training phase ---------------
    train_loss, train_preds, train_gold = _train_one_epoch(train_loader)
    train_acc, train_precision, train_recall, train_f1 = _binary_scores(
        train_gold, train_preds
    )

    # --------------- validation phase ---------------
    test_preds, test_gold = _evaluate(val_loader)
    test_acc, test_precision, test_recall, test_f1 = _binary_scores(
        test_gold, test_preds
    )

    # --------------- record metrics ---------------
    epoch_metrics = {
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "train_precision": train_precision,
        "train_recall": train_recall,
        "train_f1": train_f1,
        "val_accuracy": test_acc,
        "val_precision": test_precision,
        "val_recall": test_recall,
        "val_f1": test_f1,
    }
    metrics["train"].append(epoch_metrics)

    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
    print(f"Val Acc: {test_acc:.4f} | Precision: {test_precision:.4f} | Recall: {test_recall:.4f} | F1: {test_f1:.4f}")
    print("-" * 60)

# Persist metrics to JSON.
metrics_path = "/home/longyuanjun/workspaces/vul-LMGGNN/train_record/codet5_lstm/cwe_20/training_metrics.json"
# BUG FIX: this deep directory was never created (only "saved_models" and
# "metrics" were), so a completed 10-epoch run could crash right here and
# lose all logged metrics. Create it before opening the file.
os.makedirs(os.path.dirname(metrics_path), exist_ok=True)
with open(metrics_path, "w") as f:
    json.dump(metrics, f, indent=4)

