from train import *
from feature import *



def predict_bind(task_name):
    """Predict per-position binding probabilities for every row of a task.

    Runs ``process_allobind_predict`` for the task, evaluates the five fold
    checkpoints stored under ``pretrained/AlloBench/af3_ankh_10`` on CPU,
    averages the sigmoid outputs of the folds and writes:

    * ``task/<task_name>/predict_mean_probs.txt`` -- one line per sample
      holding the space-separated mean probabilities of the positions whose
      mask value is 1;
    * ``task/<task_name>/<protein_name>_<modulator_name>_sorted_probs.txt``
      -- one file per sample with ``<1-based position> <probability>`` lines
      sorted by descending probability.

    Args:
        task_name: Name of the directory under ``task/`` that contains the
            tab-separated ``<task_name>_processed.csv`` with the
            ``protein_name`` and ``modulator_name`` columns.
    """
    task_dir = f"task/{task_name}/"
    df_path = f"{task_dir}/{task_name}_processed.csv"
    process_allobind_predict(task_name, cutoff=10, plm_model_name='ankh', pdb_db='af3', base_dir='task')
    lig_fp_dir = f"{task_dir}/ligand/morgan"
    protein_graph_dir = f"{task_dir}/protein/af3_ankh_10_graph"

    df = pd.read_csv(df_path, sep='\t')
    # Output file stems, one per CSV row.
    # NOTE(review): assumes the dataset yields samples in CSV row order -- confirm.
    names = [f"{prot}_{mod}" for prot, mod in zip(df['protein_name'], df['modulator_name'])]

    best_model_dir = "pretrained/AlloBench/af3_ankh_10"
    device = torch.device('cpu')
    batch_size = 12
    n_folds = 5

    model = MMAlloBind(in_dim_dict={'prot_plm':1536, 'lig_fp':1024, 'prot_edge':15}, hidden_dim=512).to(device)
    predict_dataset = LigProtSiteDataset(df_path, 'modulator_name', lig_fp_dir, protein_graph_dir, 'protein_name', predict=True)
    # shuffle must stay False: the folds are averaged row by row and row i is
    # written under names[i], so every pass has to see the same sample order.
    predict_loader = DataLoader(predict_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn_bind)

    fold_probs = []
    final_mask = None

    for fold in range(n_folds):
        checkpoint = f"{best_model_dir}/model_fold{fold}.pt"
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        model.eval()

        batch_probs, batch_masks = [], []
        with torch.no_grad():
            # The third item of each batch (labels) is not needed for prediction.
            for lig_fp, prot_graph, _labels, mask_prot in predict_loader:
                lig_fp = lig_fp.to(device)
                prot_graph = prot_graph.to(device)
                mask_prot = mask_prot.to(device)
                logits = model(lig_fp, prot_graph, mask_prot).squeeze(-1)
                probs = torch.sigmoid(logits)
                # Zero out the masked-off positions before storing.
                batch_probs.append((probs * mask_prot).cpu())
                batch_masks.append(mask_prot.cpu())

        # NOTE(review): torch.cat requires every batch to share the same
        # trailing size -- presumably collate_fn_bind pads to a fixed length.
        fold_probs.append(torch.cat(batch_probs, 0))
        if final_mask is None:
            # The mask depends only on the data, so the first fold's copy is enough.
            final_mask = torch.cat(batch_masks, 0)

    mean_probs = torch.stack(fold_probs, 0).mean(0)
    print(mean_probs.size())

    with open(f"{task_dir}/predict_mean_probs.txt", "w") as summary:
        for i, row_probs in enumerate(mean_probs):
            valid_indices = torch.nonzero(final_mask[i] == 1, as_tuple=False).squeeze(-1)
            valid_probs = row_probs[valid_indices]

            summary.write(" ".join(f"{v:.6f}" for v in valid_probs.tolist()) + "\n")

            order = torch.argsort(valid_probs, descending=True)
            ranked_positions = valid_indices[order].tolist()
            ranked_probs = valid_probs[order].tolist()

            with open(f"{task_dir}/{names[i]}_sorted_probs.txt", "w") as ranked_file:
                # Positions are reported 1-based.
                ranked_file.writelines(
                    f"{pos + 1} {prob:.6f}\n"
                    for pos, prob in zip(ranked_positions, ranked_probs)
                )
            
# predict_bind('example_bind')





def predict_mut(task_name):
    """Predict one score per row of a mutation task and write ``result.txt``.

    Runs ``process_allomut`` for the task, evaluates the five fold
    checkpoints stored under ``pretrained/AM/af3_onehot_10_a`` on CPU,
    averages the sigmoid outputs of the folds and writes one value per line
    (three decimals) to ``task/<task_name>/result.txt``, in dataset order.

    Args:
        task_name: Name of the directory under ``task/`` that contains
            ``<task_name>_processed.csv``.
    """
    task_dir = f"task/{task_name}/"
    df_path = f"{task_dir}/{task_name}_processed.csv"
    process_allomut(task_name, cutoff=10, plm_model_name='onehot', pdb_db='af3', base_dir='task')
    protein_graph_dir = f"{task_dir}/protein/af3_onehot_10_graph"

    best_model_dir = "pretrained/AM/af3_onehot_10_a"
    device = torch.device('cpu')
    batch_size = 12
    n_folds = 5

    model = MMAlloMutate(in_dim_dict={'prot_plm':40, 'lig_fp':1024, 'prot_edge':15}, hidden_dim=256).to(device)
    predict_dataset = MutProtSiteDataset(df_path, protein_graph_dir, 'mut_name', predict=True)
    # shuffle must stay False: the folds are averaged element-wise and the
    # output lines must follow the dataset order on every run.
    predict_loader = DataLoader(predict_dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn_mut)

    all_fold_probs = []
    for fold in range(n_folds):
        checkpoint = f"{best_model_dir}/model_fold{fold}.pt"
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        model.eval()

        batch_probs = []
        with torch.no_grad():
            # The third item of each batch (labels) is not needed for prediction.
            for prot_graphs, sites, _labels, mask_prot in predict_loader:
                prot_graphs = prot_graphs.to(device)
                sites = sites.to(device)
                mask_prot = mask_prot.to(device)
                logits = model(prot_graphs, sites, mask_prot).squeeze(-1)
                # One scalar per sample is written below, so flatten to 1-D;
                # this also keeps a single-sample batch concatenable.
                batch_probs.append(torch.sigmoid(logits).reshape(-1).cpu())

        # Keep the predictions of every batch, not only the last one.
        all_fold_probs.append(torch.cat(batch_probs, 0))

    mean_val = torch.stack(all_fold_probs, dim=0).mean(dim=0)
    print(mean_val.size())

    with open(f"{task_dir}/result.txt", "w") as f:
        f.writelines(f"{value:.3f}\n" for value in mean_val.tolist())

# Module-level call: runs predict_mut for the 'example_mut' task every time
# this file is executed or imported.
# NOTE(review): consider an `if __name__ == "__main__":` guard if this module
# is ever imported from elsewhere.
predict_mut('example_mut')
