import os
import copy
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, SAGPooling
from torch_geometric.loader import DataLoader
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
import numpy as np
import pandas as pd
from tqdm import tqdm

class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    For each sample i with true class y_i and softmax probability p_i of that class:

        loss_i = alpha[y_i] * (1 - p_i) ** gamma * (-log p_i)

    (``alpha[y_i]`` is taken as 1 when ``alpha`` is None.)

    Args:
        alpha: optional per-class weights (1-D tensor or sequence indexed by class id).
        gamma: focusing exponent; ``gamma=0`` gives per-sample (weighted) cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns per-sample losses.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Unweighted CE is exactly -log p_t, so exp(-ce) below is the true-class
        # probability. Passing weight= here would turn pt into p_t ** alpha_t and
        # distort the (1 - pt) ** gamma modulating factor.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        if self.alpha is not None:
            # Class weights are applied after the focal term, per target class.
            alpha = torch.as_tensor(self.alpha, dtype=focal_loss.dtype, device=focal_loss.device)
            focal_loss = alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class MCDropout(nn.Module):
    """Dropout that ignores the module's train/eval mode and always samples a mask.

    Keeping dropout on at inference time is what lets repeated forward passes act
    as Monte Carlo samples.

    Args:
        p: probability of zeroing each element; survivors are rescaled by 1 / (1 - p).
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # Hard-wired training flag: unlike nn.Dropout, calling eval() does not disable this.
        always_active = True
        return F.dropout(x, p=self.p, training=always_active)

class BayesianGCN(nn.Module):
    """Graph classifier: stacked GCNConv layers, SAGPooling, mean+max readout, 2-layer MLP head.

    Every hidden activation passes through ``MCDropout``, which stays active in eval
    mode, so repeated forward passes on the same input give different (MC) outputs.

    Args:
        num_node_features: size of the input node features ``x``.
        hidden_channels: width of every GCN layer and of the MLP hidden layer.
        num_classes: number of logits produced per graph.
        dropout_rate: MC dropout probability.
        num_layers: number of GCN layers applied in ``forward`` (must be >= 1).

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, num_node_features, hidden_channels, num_classes, dropout_rate=0.5, num_layers=3):
        super(BayesianGCN, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.mc_dropout = MCDropout(p=dropout_rate)

        # First layer maps input features to the hidden width.
        self.convs.append(GCNConv(num_node_features, hidden_channels))
        self.bns.append(nn.BatchNorm1d(hidden_channels))

        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))

        # NOTE: with num_layers == 1 this appends a second conv/bn pair that forward()
        # never uses; kept so state_dict keys stay compatible with saved checkpoints.
        self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.bns.append(nn.BatchNorm1d(hidden_channels))

        # ratio=0.8: keeps roughly 80% of the nodes of each graph.
        self.pool = SAGPooling(hidden_channels, ratio=0.8)

        # The readout concatenates mean- and max-pooled embeddings, hence 2 * hidden.
        self.lin1 = nn.Linear(hidden_channels * 2, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Return per-graph logits.

        ``edge_attr``, when given, is used as the GCN edge weight. It presumably has
        shape ``[E]`` or ``[E, 1]`` (TODO confirm against the data files).
        """
        edge_weight = None
        if edge_attr is not None:
            # Drop only a trailing singleton feature dim. A bare squeeze() would also
            # collapse a single-edge [1, 1] tensor into a 0-d scalar that GCNConv rejects.
            edge_weight = edge_attr.squeeze(-1) if edge_attr.dim() > 1 else edge_attr

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_weight)
            x = self.bns[i](x)
            x = F.relu(x)
            x = self.mc_dropout(x)

        x, edge_index, edge_attr, batch, _, _ = self.pool(x, edge_index, edge_attr, batch)

        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)

        x = self.lin1(x)
        x = F.relu(x)
        x = self.mc_dropout(x)
        x = self.lin2(x)

        return x

class UncertaintyEngine:
    """Monte Carlo dropout inference over a graph ``DataLoader``.

    Each batch is run through ``model`` ``num_mc_samples`` times with dropout active.
    The softmax outputs are averaged, and the predictive entropy of that average is
    computed.

    Args:
        model: module called as ``model(x, edge_index, edge_attr, batch)`` that returns
            per-graph logits (softmax is taken over dim 1).
        device: device each batch is moved to before the forward pass.
        num_mc_samples: number of stochastic forward passes ``T`` per batch.
    """

    def __init__(self, model, device, num_mc_samples=50):
        self.model = model
        self.device = device
        self.T = num_mc_samples

    def _enable_mc_inference(self):
        """Put ``model`` in eval mode, then switch standard dropout layers back on.

        eval() makes BatchNorm use (and stop updating) its running statistics, so MC
        inference neither depends on batch composition nor overwrites the statistics
        learned during training. MCDropout layers always drop, so they need nothing extra.
        """
        self.model.eval()
        for module in self.model.modules():
            if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)):
                module.train()

    @staticmethod
    def _as_id_list(slide_ids):
        """Normalise a batch's ``slide_id`` attribute into a flat Python list."""
        if isinstance(slide_ids, torch.Tensor):
            return slide_ids.view(-1).tolist()
        if isinstance(slide_ids, (list, tuple)):
            return list(slide_ids)
        # A single scalar/string id: wrap it instead of iterating its characters.
        return [slide_ids]

    def predict(self, loader):
        """Run MC-dropout inference over ``loader``.

        Returns:
            means: array of mean softmax probabilities, one row per graph.
            entropies: predictive entropy (natural log) of each row of ``means``.
            targets: concatenated ``data.y`` values.
            ids: list of ``slide_id`` values; empty if the batches carry no ``slide_id``.

        Raises:
            ValueError: if ``loader`` yields no batches.

        The model's previous train/eval mode is restored on exit.
        """
        was_training = self.model.training
        self._enable_mc_inference()
        all_means, all_entropies, all_targets, all_ids = [], [], [], []

        try:
            with torch.no_grad():
                for data in tqdm(loader, desc="Uncertainty Estimation"):
                    data = data.to(self.device)
                    # Stack of T per-pass probability tensors along a new leading dim.
                    sampled = torch.stack([
                        F.softmax(self.model(data.x, data.edge_index, data.edge_attr, data.batch), dim=1)
                        for _ in range(self.T)
                    ])
                    mean_probs = sampled.mean(dim=0)
                    # Entropy of the averaged distribution; 1e-8 guards against log(0).
                    predictive_entropy = -torch.sum(mean_probs * torch.log(mean_probs + 1e-8), dim=1)

                    all_means.append(mean_probs.cpu().numpy())
                    all_entropies.append(predictive_entropy.cpu().numpy())
                    all_targets.append(data.y.cpu().numpy())
                    slide_ids = getattr(data, 'slide_id', None)
                    if slide_ids is not None:
                        all_ids.extend(self._as_id_list(slide_ids))
        finally:
            self.model.train(was_training)

        if not all_means:
            raise ValueError("loader yielded no batches")
        return np.concatenate(all_means), np.concatenate(all_entropies), np.concatenate(all_targets), all_ids

class ModelTrainer:
    """Build a BayesianGCN with optimiser, focal loss and LR scheduler from ``config``, and train it.

    Required ``config`` keys: 'in_dim', 'hidden_dim', 'num_classes', 'dropout',
    'layers', 'lr', 'weight_decay'.

    Optional keys:
        'class_weights': per-class focal-loss weights; None, empty or missing means unweighted.
        'focal_gamma': focal-loss gamma, default 2.0.
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.model = BayesianGCN(
            num_node_features=config['in_dim'],
            hidden_channels=config['hidden_dim'],
            num_classes=config['num_classes'],
            dropout_rate=config['dropout'],
            num_layers=config['layers'],
        ).to(self.device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config['lr'],
            weight_decay=config['weight_decay'],
        )

        # `.get` plus an explicit length check: a missing key means "unweighted", and a
        # tensor value no longer hits the ambiguous-truthiness error of `if tensor:`.
        raw_weights = config.get('class_weights')
        class_weights = None
        if raw_weights is not None and len(raw_weights) > 0:
            class_weights = torch.as_tensor(raw_weights, dtype=torch.float32, device=self.device)
        self.criterion = FocalLoss(alpha=class_weights, gamma=config.get('focal_gamma', 2.0))
        # Halve the LR when the validation loss has not improved for 10 epochs.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.5, patience=10)

    def train_epoch(self, loader):
        """Run one optimisation pass over ``loader``.

        Returns:
            (average loss per graph, accuracy) over the graphs seen.
        """
        self.model.train()
        total_loss = 0.0
        num_graphs = 0
        all_preds = []
        all_labels = []

        for data in loader:
            data = data.to(self.device)
            self.optimizer.zero_grad()

            out = self.model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = self.criterion(out, data.y)

            loss.backward()
            # Clip the global gradient norm to 1.0 to guard against exploding gradients.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # The criterion returns a batch mean; re-weight by batch size for a per-graph average.
            total_loss += loss.item() * data.num_graphs
            num_graphs += data.num_graphs
            all_preds.extend(out.argmax(dim=1).cpu().numpy())
            all_labels.extend(data.y.cpu().numpy())

        # Count graphs actually seen, so this stays correct if the loader drops a last batch.
        avg_loss = total_loss / max(num_graphs, 1)
        acc = accuracy_score(all_labels, all_preds)
        return avg_loss, acc

    def evaluate(self, loader, use_mc=False, num_mc_samples=10):
        """Score the model on ``loader``.

        Args:
            loader: loader of labelled graphs.
            use_mc: if True, predict from the mean of ``num_mc_samples`` MC-dropout passes
                via UncertaintyEngine. No loss is computed on this path; 0.0 is returned.
            num_mc_samples: passes per batch when ``use_mc`` is True.

        Returns:
            (loss, accuracy, macro F1).

        Note: BayesianGCN's dropout is MCDropout, which stays active under eval(), so
        the ``use_mc=False`` path is a single stochastic pass, not a deterministic one.
        """
        if use_mc:
            engine = UncertaintyEngine(self.model, self.device, num_mc_samples=num_mc_samples)
            probs, _, targets, _ = engine.predict(loader)
            preds = np.argmax(probs, axis=1)
            loss = 0.0
        else:
            self.model.eval()
            total_loss = 0.0
            num_graphs = 0
            all_preds = []
            all_labels = []

            with torch.no_grad():
                for data in loader:
                    data = data.to(self.device)
                    out = self.model(data.x, data.edge_index, data.edge_attr, data.batch)
                    batch_loss = self.criterion(out, data.y)
                    total_loss += batch_loss.item() * data.num_graphs
                    num_graphs += data.num_graphs
                    all_preds.extend(out.argmax(dim=1).cpu().numpy())
                    all_labels.extend(data.y.cpu().numpy())

            loss = total_loss / max(num_graphs, 1)
            preds = np.array(all_preds)
            targets = np.array(all_labels)

        acc = accuracy_score(targets, preds)
        f1 = f1_score(targets, preds, average='macro')
        return loss, acc, f1

    def fit(self, train_loader, val_loader, epochs=100, save_path='best_model.pth', patience=20):
        """Train with early stopping on validation macro F1, checkpointing the best epoch.

        Args:
            train_loader: training loader.
            val_loader: validation loader used for model selection and LR scheduling.
            epochs: maximum number of epochs.
            save_path: where the best model's state_dict is written.
            patience: epochs without F1 improvement before stopping early.

        Returns:
            The best validation macro F1 seen (``-inf`` if ``epochs`` is 0).
        """
        # Start below any achievable F1 so the first epoch is always checkpointed.
        # Starting at 0.0 meant save_path was never written when F1 stayed at 0.
        best_f1 = float('-inf')
        counter = 0

        for epoch in range(epochs):
            train_loss, _ = self.train_epoch(train_loader)
            val_loss, val_acc, val_f1 = self.evaluate(val_loader)

            self.scheduler.step(val_loss)

            print(f'Epoch {epoch+1:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}')

            if val_f1 > best_f1:
                best_f1 = val_f1
                torch.save(self.model.state_dict(), save_path)
                counter = 0
            else:
                counter += 1
                if counter >= patience:
                    print("Early stopping triggered")
                    break

        print(f"Training complete. Best F1: {best_f1:.4f}")
        return best_f1

def main():
    """Train a BayesianGCN on per-slide graph files and report MC-dropout uncertainty on a test split.

    Loads every ``*.pt`` file in ``--dataset_dir``. Each file presumably holds a PyG ``Data``
    object with ``x``, ``edge_index``, ``edge_attr``, ``y`` and optionally ``slide_id``
    (TODO confirm).

    Steps:
        1. Split the graphs 70/15/15 into train/val/test.
        2. Train with inverse-frequency class-weighted focal loss.
        3. Write per-graph labels, predictions, predictive entropy and class
           probabilities to ``<save_dir>/test_results_uncertainty.csv``.
    """
    import argparse
    from torch_geometric.data import Dataset

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, required=True)
    parser.add_argument('--save_dir', type=str, default='./models')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--mc_samples', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--num_classes', type=int, default=3)
    parser.add_argument('--seed', type=int, default=None,
                        help='seed for the train/val/test split (unseeded when omitted)')
    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    class WSIGraphDataset(Dataset):
        """Lazily loads one graph per ``.pt`` file in ``root`` (files sorted by path)."""

        def __init__(self, root):
            self.root = root
            self.files = sorted([os.path.join(root, f) for f in os.listdir(root) if f.endswith('.pt')])
            super().__init__(root)

        def len(self):
            return len(self.files)

        def get(self, idx):
            # PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
            # unpickle graph objects. SECURITY: weights_only=False runs full pickle, so
            # only point --dataset_dir at trusted files.
            return torch.load(self.files[idx], weights_only=False)

    full_dataset = WSIGraphDataset(args.dataset_dir)

    total_size = len(full_dataset)
    train_size = int(0.7 * total_size)
    val_size = int(0.15 * total_size)
    test_size = total_size - train_size - val_size
    # An empty split would crash later (division by zero in training/evaluation).
    if min(train_size, val_size, test_size) == 0:
        parser.error(f"need at least one graph per split; found {total_size} .pt file(s) in {args.dataset_dir}")

    split_kwargs = {} if args.seed is None else {'generator': torch.Generator().manual_seed(args.seed)}
    train_set, val_set, test_set = torch.utils.data.random_split(
        full_dataset, [train_size, val_size, test_size], **split_kwargs)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

    num_features = full_dataset[0].x.shape[1]
    num_classes = args.num_classes

    label_counts = torch.zeros(num_classes)
    for data in train_loader:
        labels = data.y
        for c in range(num_classes):
            label_counts[c] += (labels == c).sum()

    # Inverse-frequency weights, normalised to sum to 1. A class absent from training
    # gets weight 0 rather than ~1e6; the huge value would otherwise dominate the
    # normalisation and push every other class weight to ~0.
    present = label_counts > 0
    if not present.any():
        raise ValueError(f"no training labels fall in 0..{num_classes - 1}; check --num_classes")
    weights = torch.zeros(num_classes)
    weights[present] = 1.0 / label_counts[present]
    weights = weights / weights.sum()
    print(f"Class weights: {weights}")

    config = {
        'in_dim': num_features,
        'hidden_dim': args.hidden_dim,
        'num_classes': num_classes,
        'dropout': args.dropout,
        'layers': 3,
        'lr': args.lr,
        'weight_decay': 5e-4,
        'class_weights': weights.tolist(),
    }

    trainer = ModelTrainer(config)
    save_path = os.path.join(args.save_dir, 'bayesian_gcn.pth')

    trainer.fit(train_loader, val_loader, epochs=args.epochs, save_path=save_path)

    print("Loading best model for testing...")
    # map_location lets a checkpoint written on one device be restored on the trainer's device.
    trainer.model.load_state_dict(torch.load(save_path, map_location=trainer.device))

    unc_engine = UncertaintyEngine(trainer.model, trainer.device, num_mc_samples=args.mc_samples)
    probs, entropies, targets, ids = unc_engine.predict(test_loader)

    final_preds = np.argmax(probs, axis=1)
    test_acc = accuracy_score(targets, final_preds)
    test_f1 = f1_score(targets, final_preds, average='macro')

    print(f"Test Accuracy (MC): {test_acc:.4f}")
    print(f"Test F1 (MC): {test_f1:.4f}")

    if len(ids) != len(targets):
        # predict() only collects ids from batches that carry slide_id; keep the CSV writable.
        print(f"Warning: got {len(ids)} slide ids for {len(targets)} test graphs; leaving slide_id empty.")
        ids = [None] * len(targets)

    results = {
        'slide_id': ids,
        'true_label': targets,
        'pred_label': final_preds,
        'entropy': entropies,
    }
    for c in range(probs.shape[1]):
        results[f'prob_{c}'] = probs[:, c]
    results_df = pd.DataFrame(results)

    results_df.to_csv(os.path.join(args.save_dir, 'test_results_uncertainty.csv'), index=False)

    # Flag the top 10% most uncertain test graphs by predictive entropy.
    high_uncertainty_threshold = np.percentile(entropies, 90)
    uncertain_cases = results_df[results_df['entropy'] > high_uncertainty_threshold]
    print(f"Identified {len(uncertain_cases)} high uncertainty cases (Threshold > {high_uncertainty_threshold:.4f})")

if __name__ == "__main__":
    # Run the training/evaluation CLI only when executed as a script, not on import.
    main()
