import os
import sys
import glob
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, global_max_pool, SAGPooling
from torch_geometric.data import Dataset, Batch
from torch_geometric.loader import DataLoader
from sklearn.metrics import (confusion_matrix, roc_curve, auc, 
                             precision_recall_curve, average_precision_score, 
                             accuracy_score, brier_score_loss)
from scipy.interpolate import make_interp_spline

# Process-wide Matplotlib appearance, applied once at import time so every
# figure saved by this module shares the same grid style and serif font.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'font.family': 'serif'})

class MCDropout(nn.Module):
    """Dropout layer that stays stochastic regardless of train/eval mode.

    The functional dropout is always invoked with ``training=True``, so
    calling ``.eval()`` on this module (or on a parent) does not switch the
    random masking off.

    Args:
        p: Probability of zeroing an element.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        """Return ``x`` with dropout applied at rate ``self.p``."""
        dropped = F.dropout(x, p=self.p, training=True)
        return dropped

class BayesianGCN(nn.Module):
    """Graph classifier built from GCN layers with always-on dropout.

    Pipeline, as implemented in ``forward``:
    ``[GCNConv -> BatchNorm1d -> ReLU -> MCDropout] x num_layers`` ->
    ``SAGPooling(ratio=0.8)`` -> concatenated mean/max graph readout ->
    ``Linear -> ReLU -> MCDropout -> Linear``.

    Args:
        num_node_features: Width of the input node features.
        hidden_channels: Width of every hidden representation.
        num_classes: Number of output columns produced by the last linear layer.
        dropout_rate: Drop probability handed to ``MCDropout``.
        num_layers: Number of conv blocks applied in ``forward``. The
            constructor always creates at least two conv/batch-norm pairs
            (first and last), so for ``num_layers < 2`` the trailing pair
            exists in the state dict but is not executed.
    """

    def __init__(self, num_node_features, hidden_channels, num_classes, dropout_rate=0.5, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.mc_dropout = MCDropout(p=dropout_rate)

        # First block maps input features to the hidden width; every later
        # block is hidden -> hidden. Creation order (conv, bn, conv, bn, ...)
        # is kept so seeded initialisation and state-dict keys do not change.
        in_channels = num_node_features
        for _ in range(max(num_layers - 2, 0) + 2):
            self.convs.append(GCNConv(in_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
            in_channels = hidden_channels

        self.pool = SAGPooling(hidden_channels, ratio=0.8)

        # The readout concatenates mean- and max-pooled features, hence 2x width.
        self.lin1 = nn.Linear(hidden_channels * 2, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Compute the raw (pre-softmax) output of the final linear layer.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to every conv layer.
            edge_attr: Optional per-edge attribute used as the GCN edge weight.
            batch: Optional node-to-graph assignment vector.

        Returns:
            The output of ``self.lin2`` applied to the pooled graph readout.

        Side effects:
            Stores the node features and batch vector that come out of
            ``SAGPooling`` on ``self.final_conv_acts`` / ``self.final_batch``.
        """
        edge_weight = None
        if edge_attr is not None:
            # Drop only a trailing feature axis. A bare ``squeeze()`` would
            # also collapse the edge axis of a single-edge graph ([1, 1] ->
            # 0-dim), which is not a valid per-edge weight vector.
            edge_weight = edge_attr.squeeze(-1) if edge_attr.dim() > 1 else edge_attr

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_weight)
            x = self.bns[i](x)
            x = F.relu(x)
            x = self.mc_dropout(x)

        # SAGPooling returns a 6-tuple; the last two entries are not needed here.
        x, edge_index, edge_attr, batch, _, _ = self.pool(x, edge_index, edge_attr, batch)

        # Kept on the module for later inspection (these are post-pooling values).
        self.final_conv_acts = x
        self.final_batch = batch

        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)

        x = self.lin1(x)
        x = F.relu(x)
        x = self.mc_dropout(x)
        x = self.lin2(x)

        return x

class GraphGradCAM:
    """Grad-CAM style per-node attribution for a graph classifier.

    Forward and full-backward hooks are attached to ``model.convs[-1]`` at
    construction time. ``generate_cam`` then weights that layer's output
    channels by the mean gradient of the chosen output score.

    Args:
        model: Module exposing a ``convs`` sequence and callable as
            ``model(x, edge_index, edge_attr, batch)``.
        device: Device the input data is moved to before the forward pass.
    """

    def __init__(self, model, device):
        self.model = model
        self.device = device
        self.gradients = None
        self.activations = None
        self.handlers = []

        self._register_hooks()

    def _register_hooks(self):
        """Attach the capture hooks to the last conv layer of the model."""
        def forward_hook(module, input, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        target_layer = self.model.convs[-1]
        self.handlers.append(target_layer.register_forward_hook(forward_hook))
        self.handlers.append(target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach every registered hook; safe to call more than once."""
        for handle in self.handlers:
            handle.remove()
        # Forget the dead handles so the list reflects what is still attached.
        self.handlers.clear()

    def generate_cam(self, data, target_class=None):
        """Compute a min-max normalised activation map for one input.

        Args:
            data: Object providing ``to(device)`` plus ``x``, ``edge_index``,
                ``edge_attr`` and ``batch`` attributes.
            target_class: Output column to explain. When ``None`` the argmax
                of the model output is used (this requires a single output row).

        Returns:
            1-D NumPy array with one value in ``[0, 1]`` per row of the hooked
            layer's output, or ``None`` if the hooks captured nothing.

        Note:
            The model is switched to ``eval()`` mode, but layers that force
            dropout on regardless of mode (such as ``MCDropout`` in this file)
            stay stochastic, so repeated calls may return different maps.
        """
        self.model.eval()
        self.model.zero_grad()

        # Clear whatever an earlier call captured; otherwise a hook that does
        # not fire would silently yield the previous input's map.
        self.gradients = None
        self.activations = None

        data = data.to(self.device)

        # The backward pass needs an autograd graph even if the caller is
        # inside a ``torch.no_grad()`` evaluation block.
        with torch.enable_grad():
            out = self.model(data.x, data.edge_index, data.edge_attr, data.batch)

            if target_class is None:
                target_class = out.argmax(dim=1).item()

            target = out[0, target_class]
            target.backward()

        grads = self.gradients
        acts = self.activations

        if grads is None or acts is None:
            return None

        grads = grads.detach()
        acts = acts.detach()

        # Channel importance = gradient averaged over rows (nodes).
        weights = torch.mean(grads, dim=0)
        cam = torch.sum(weights * acts, dim=1)
        cam = F.relu(cam)

        # Epsilon keeps the division finite when the map is constant.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        return cam.cpu().numpy()

class CalibrationEvaluator:
    """Top-label calibration metrics over equal-width confidence bins.

    Confidence is the largest probability in each row; a sample counts as
    accurate when the argmax column equals its label. Bins are the
    half-open intervals ``(lower, upper]`` from ``np.linspace(0, 1, n_bins + 1)``.

    Args:
        n_bins: Number of confidence bins.
    """

    def __init__(self, n_bins=15):
        self.n_bins = n_bins

    def _bin_statistics(self, probs, labels):
        """Bin the samples once for both the ECE and the reliability diagram.

        Returns:
            Tuple ``(bin_lowers, counts, accuracies, confidences, n_samples)``.
            For an empty bin the accuracy is 0 and the confidence is the bin
            midpoint.
        """
        probs = np.asarray(probs, dtype=float)
        # Flatten so a column vector of labels compares element-wise instead
        # of broadcasting into an N x N matrix.
        labels = np.asarray(labels).reshape(-1)

        if probs.size == 0:
            confidences = np.empty(0, dtype=float)
            correct = np.empty(0, dtype=bool)
        else:
            confidences = np.max(probs, axis=1)
            correct = np.argmax(probs, axis=1) == labels

        bin_boundaries = np.linspace(0, 1, self.n_bins + 1)
        bin_lowers = bin_boundaries[:-1]
        bin_uppers = bin_boundaries[1:]

        counts = np.zeros(len(bin_lowers), dtype=int)
        bin_accs = np.zeros(len(bin_lowers), dtype=float)
        bin_confs = (bin_lowers + bin_uppers) / 2

        for idx, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
            in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
            counts[idx] = int(np.sum(in_bin))
            if counts[idx] > 0:
                bin_accs[idx] = np.mean(correct[in_bin])
                bin_confs[idx] = np.mean(confidences[in_bin])

        return bin_lowers, counts, bin_accs, bin_confs, len(confidences)

    def compute_ece(self, probs, labels):
        """Return the Expected Calibration Error as a Python float.

        Args:
            probs: Array-like of per-class probabilities, one row per sample.
            labels: Array-like of integer class labels (any shape that
                flattens to one label per sample).

        Returns:
            Sum over non-empty bins of ``|mean confidence - accuracy|``
            weighted by the fraction of all samples in the bin; ``0.0`` for
            empty input.
        """
        _, counts, bin_accs, bin_confs, n_samples = self._bin_statistics(probs, labels)
        if n_samples == 0:
            return 0.0

        ece = 0.0
        for count, acc, conf in zip(counts, bin_accs, bin_confs):
            if count > 0:
                ece += abs(conf - acc) * (count / n_samples)
        return float(ece)

    def plot_reliability_diagram(self, probs, labels, save_path):
        """Save a reliability diagram (accuracy vs. confidence) to ``save_path``.

        Bars show the per-bin accuracy (zero height for empty bins). The model
        curve is drawn through populated bins only, so bins without samples
        do not appear as spurious zero-accuracy points.
        """
        bin_lowers, counts, bin_accs, bin_confs, _ = self._bin_statistics(probs, labels)
        populated = counts > 0

        plt.figure(figsize=(8, 8))
        plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfect Calibration')
        plt.plot(bin_confs[populated], bin_accs[populated], marker='o', linewidth=2, label='Model')
        plt.bar(bin_lowers, bin_accs, width=1.0/self.n_bins, alpha=0.3, align='edge', edgecolor='black')
        plt.xlabel('Confidence')
        plt.ylabel('Accuracy')
        plt.title('Reliability Diagram')
        plt.legend()
        plt.savefig(save_path, dpi=300)
        plt.close()

class Visualizer:
    """Renders per-patch weights as a colour heatmap image.

    Args:
        patch_size: Edge length of one patch, in the same units as the
            coordinates passed to ``generate_heatmap``.
    """

    def __init__(self, patch_size=50):
        self.patch_size = patch_size

    def generate_heatmap(self, coords, cam_weights, save_path, slide_id):
        """Draw one filled square per patch and write the image to disk.

        Args:
            coords: ``(N, 2)`` array of patch ``(x, y)`` positions.
            cam_weights: Sequence of per-patch weights; patches beyond its
                length are drawn with weight 0.
            save_path: Destination image path handed to ``cv2.imwrite``.
            slide_id: Identifier rendered as a caption on the image.

        Returns:
            ``None``. Nothing is drawn or written when ``coords`` is empty.
        """
        if len(coords) == 0:
            return

        coords = np.asarray(coords).astype(np.int32)
        min_x, min_y = np.min(coords, axis=0)
        max_x, max_y = np.max(coords, axis=0)

        w = max_x - min_x + self.patch_size
        h = max_y - min_y + self.patch_size

        # Work on a 10x downscaled canvas; clamp to one pixel so a tiny
        # extent cannot produce a zero-sized image that OpenCV rejects.
        sf = 0.1
        w_small = max(1, int(w * sf))
        h_small = max(1, int(h * sf))

        canvas = np.full((h_small, w_small, 3), 255, dtype=np.uint8)
        heatmap_overlay = np.zeros((h_small, w_small), dtype=np.float32)

        scaled_coords = ((coords - [min_x, min_y]) * sf).astype(np.int32)
        scaled_ps = int(self.patch_size * sf)

        for i, (x, y) in enumerate(scaled_coords):
            # Hand OpenCV built-in int/float values rather than NumPy scalars,
            # which its argument parser does not accept in every version.
            weight = float(cam_weights[i]) if i < len(cam_weights) else 0.0
            x, y = int(x), int(y)
            cv2.rectangle(heatmap_overlay, (x, y), (x + scaled_ps, y + scaled_ps), weight, -1)

        # Smooth patch borders, then rescale to [0, 1] for colour mapping.
        heatmap_overlay = cv2.GaussianBlur(heatmap_overlay, (15, 15), 0)
        heatmap_overlay = (heatmap_overlay - heatmap_overlay.min()) / (heatmap_overlay.max() - heatmap_overlay.min() + 1e-8)

        heatmap_color = cv2.applyColorMap(np.uint8(255 * heatmap_overlay), cv2.COLORMAP_JET)

        # Only blend colour where the normalised weight exceeds 0.05;
        # everything else stays plain white canvas.
        tissue_mask = heatmap_overlay > 0.05
        tissue_mask = np.stack([tissue_mask]*3, axis=-1)

        final_img = np.where(tissue_mask,
                           cv2.addWeighted(canvas, 0.3, heatmap_color, 0.7, 0),
                           canvas)

        cv2.putText(final_img, f"Slide: {slide_id}", (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,0), 2)

        # imwrite signals failure through its return value instead of raising.
        if not cv2.imwrite(save_path, final_img):
            print(f"Warning: could not write heatmap to {save_path}")

def plot_confusion_matrix(y_true, y_pred, classes, save_path):
    """Save a row-normalised confusion-matrix heatmap.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        classes: Tick labels for both axes.
        save_path: Destination path of the saved figure.

    Each row is divided by its total so cells show the fraction of that true
    class assigned to each prediction. A row with no true samples (a label
    that occurs only among the predictions) is shown as all zeros.
    """
    cm = confusion_matrix(y_true, y_pred)

    # Guarded division: rows whose total is zero keep the zero fill instead
    # of becoming NaN through 0/0.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals > 0,
    )

    # NOTE(review): the matrix covers only labels present in y_true/y_pred;
    # if a class is missing from both, it will be smaller than `classes`.
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Normalized Confusion Matrix')
    plt.savefig(save_path, dpi=300)
    plt.close()

def plot_uncertainty_dist(results_df, save_dir):
    """Save two KDE figures of the ``entropy`` column into ``save_dir``.

    * ``uncertainty_dist.png`` - one density curve per distinct ``true_label``.
    * ``uncertainty_correct_vs_error.png`` - density for rows where
      ``true_label`` equals ``pred_label`` versus rows where it does not.

    Args:
        results_df: DataFrame with ``true_label``, ``pred_label`` and
            ``entropy`` columns.
        save_dir: Directory receiving both PNG files.
    """
    true_labels = results_df['true_label']

    # Figure 1: entropy density split by ground-truth class.
    plt.figure(figsize=(10, 6))
    for class_value in sorted(true_labels.unique()):
        class_entropy = results_df.loc[true_labels == class_value, 'entropy']
        sns.kdeplot(class_entropy, label=f'Class {class_value}', fill=True, alpha=0.3)

    plt.xlabel('Predictive Entropy (Uncertainty)')
    plt.ylabel('Density')
    plt.title('Uncertainty Distribution by Class')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'uncertainty_dist.png'), dpi=300)
    plt.close()

    # Figure 2: entropy density split by whether the prediction was right.
    is_correct = true_labels == results_df['pred_label']
    groups = (
        (is_correct, 'Correct Predictions', 'green'),
        (~is_correct, 'Incorrect Predictions', 'red'),
    )

    plt.figure(figsize=(10, 6))
    for row_mask, caption, colour in groups:
        sns.kdeplot(results_df.loc[row_mask, 'entropy'], label=caption, fill=True, color=colour, alpha=0.3)
    plt.xlabel('Predictive Entropy')
    plt.title('Uncertainty: Correct vs Incorrect')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'uncertainty_correct_vs_error.png'), dpi=300)
    plt.close()

def _plot_multiclass_roc(y_true, probs, classes, save_path):
    """Save one-vs-rest ROC curves, one per entry of ``classes``."""
    plt.figure(figsize=(10, 8))
    for i, cls in enumerate(classes):
        fpr, tpr, _ = roc_curve(y_true == i, probs[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{cls} (AUC = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multiclass ROC Curve')
    plt.legend()
    plt.savefig(save_path, dpi=300)
    plt.close()


def _render_gradcam_heatmaps(args, test_results, viz_dir, device):
    """Write Grad-CAM heatmaps for the 5 highest- and 5 lowest-entropy slides."""
    model = BayesianGCN(num_node_features=args.feat_dim,
                       hidden_channels=256,
                       num_classes=3).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))

    grad_cam = GraphGradCAM(model, device)
    visualizer = Visualizer(patch_size=args.patch_size)

    target_ids = test_results.sort_values('entropy', ascending=False).head(5)['slide_id'].tolist()
    target_ids += test_results.sort_values('entropy', ascending=True).head(5)['slide_id'].tolist()
    # With fewer than ten slides the two rankings overlap; keep first
    # occurrences only so no slide is rendered twice.
    target_ids = list(dict.fromkeys(target_ids))

    try:
        for slide_id in target_ids:
            graph_path = os.path.join(args.graph_dir, f"{slide_id}.pt")
            if not os.path.exists(graph_path):
                continue

            # Load onto the CPU so the coordinates below can be converted to
            # NumPy even if the graph was saved from a GPU; generate_cam moves
            # its own batched copy to the target device.
            # NOTE(review): torch.load unpickles arbitrary objects - only use
            # trusted graph files. Newer PyTorch releases may also need
            # weights_only=False to load pickled graph objects - confirm
            # against the pinned torch version.
            data = torch.load(graph_path, map_location='cpu')

            cam_weights = grad_cam.generate_cam(Batch.from_data_list([data]))
            if cam_weights is None:
                continue

            # The attribute can exist yet hold None, so test the value rather
            # than relying on hasattr.
            pos = getattr(data, 'pos', None)
            if pos is not None:
                coords = pos.detach().cpu().numpy()
            else:
                coords = np.zeros((len(cam_weights), 2))

            # NOTE(review): the file name says "pred" but the value comes
            # from data.y - presumably the stored label; verify intent.
            save_path = os.path.join(viz_dir, f"heatmap_{slide_id}_pred{data.y.item()}.png")
            visualizer.generate_heatmap(coords, cam_weights, save_path, slide_id)
    finally:
        # Always detach the hooks, even if a slide fails part-way through.
        grad_cam.remove_hooks()


def generate_full_report(args):
    """Produce all evaluation artefacts for a test-results CSV.

    Reads ``args.test_csv`` (columns ``true_label``, ``pred_label``,
    ``prob_0``..``prob_2``, ``entropy`` and, for heatmaps, ``slide_id``) and
    writes into ``args.output_dir``:

    * ``metrics/`` - confusion matrix, ROC curves, reliability diagram,
      uncertainty plots and ``summary_metrics.json``.
    * ``visualizations/`` - Grad-CAM heatmaps, only when
      ``args.generate_heatmaps`` is set (this additionally uses
      ``args.model_path``, ``args.graph_dir``, ``args.feat_dim`` and
      ``args.patch_size``).

    Args:
        args: Namespace-like object carrying the attributes listed above.
    """
    test_results = pd.read_csv(args.test_csv)
    y_true = test_results['true_label'].values
    y_pred = test_results['pred_label'].values
    probs = test_results[['prob_0', 'prob_1', 'prob_2']].values

    metrics_dir = os.path.join(args.output_dir, 'metrics')
    viz_dir = os.path.join(args.output_dir, 'visualizations')
    os.makedirs(metrics_dir, exist_ok=True)
    os.makedirs(viz_dir, exist_ok=True)

    classes = ['Tuberculosis', 'Brucellosis', 'Pyogenic']

    print("Generating Confusion Matrix...")
    plot_confusion_matrix(y_true, y_pred, classes, os.path.join(metrics_dir, 'confusion_matrix.png'))

    print("Generating ROC Curves...")
    _plot_multiclass_roc(y_true, probs, classes, os.path.join(metrics_dir, 'roc_curve.png'))

    print("Analyzing Calibration...")
    calib = CalibrationEvaluator(n_bins=15)
    ece = calib.compute_ece(probs, y_true)
    print(f"Expected Calibration Error (ECE): {ece:.4f}")
    calib.plot_reliability_diagram(probs, y_true, os.path.join(metrics_dir, 'reliability_diagram.png'))

    print("Analyzing Uncertainty...")
    plot_uncertainty_dist(test_results, metrics_dir)

    # The Brier score here is top-label: it compares the maximum probability
    # with whether the recorded prediction was correct. Values are cast to
    # built-in floats so the JSON encoder never sees NumPy scalar types.
    summary = {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        'ece': float(ece),
        'brier_score': float(brier_score_loss(y_true == y_pred, np.max(probs, axis=1))),
    }
    with open(os.path.join(metrics_dir, 'summary_metrics.json'), 'w') as f:
        json.dump(summary, f, indent=4)

    if args.generate_heatmaps:
        print("Generating Grad-CAM Heatmaps for selected cases...")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        _render_gradcam_heatmaps(args, test_results, viz_dir, device)
        print("Heatmap generation complete.")

if __name__ == "__main__":
    # Command-line entry point: parse options and run the full report.
    parser = argparse.ArgumentParser(description="WSI Classification Evaluation and Explainability")
    parser.add_argument('--test_csv', type=str, required=True, help='Output CSV from Module 4')
    parser.add_argument('--output_dir', type=str, default='./evaluation_results')
    # The model checkpoint and graph directory are only read while rendering
    # heatmaps, so they are mandatory only together with --generate_heatmaps.
    parser.add_argument('--model_path', type=str, default=None,
                        help='Model checkpoint (required with --generate_heatmaps)')
    parser.add_argument('--graph_dir', type=str, default=None,
                        help='Directory containing processed .pt graph files (required with --generate_heatmaps)')
    parser.add_argument('--generate_heatmaps', action='store_true')
    parser.add_argument('--feat_dim', type=int, default=1024)
    parser.add_argument('--patch_size', type=int, default=1024)

    args = parser.parse_args()

    if args.generate_heatmaps and not (args.model_path and args.graph_dir):
        parser.error('--generate_heatmaps requires --model_path and --graph_dir')

    generate_full_report(args)
