import os
import glob
import logging
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.utils import from_scipy_sparse_matrix, dense_to_sparse
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
from scipy import sparse
from tqdm import tqdm

# Configure the root logger: INFO level and above, each record rendered as
# "<timestamp> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class GraphConfig:
    """Plain holder for graph-construction hyper-parameters.

    The constructor arguments are stored under shorter attribute names:

    * ``k_neighbors``       -> ``k``
    * ``radius``            -> ``radius``
    * ``feature_threshold`` -> ``feat_thresh``
    * ``lambda_weight``     -> ``lamb``
    * ``spatial_sigma``     -> ``sigma``
    * ``prune_ratio``       -> ``prune_ratio``

    No validation is performed on any value.
    """

    def __init__(self, k_neighbors=8, radius=2048, feature_threshold=0.5,
                 lambda_weight=0.5, spatial_sigma=100.0, prune_ratio=0.0):
        # Neighbourhood settings.
        self.k, self.radius = k_neighbors, radius
        # Edge-scoring settings.
        self.feat_thresh, self.lamb, self.sigma = (
            feature_threshold,
            lambda_weight,
            spatial_sigma,
        )
        # Node-pruning setting.
        self.prune_ratio = prune_ratio

class GraphBuilder:
    def __init__(self, config: GraphConfig):
        self.config = config

    def build(self, features: torch.Tensor, coords: torch.Tensor, roi_scores: torch.Tensor) -> Data:
        num_nodes = features.shape[0]
        
        if self.config.prune_ratio > 0:
            k_keep = int(num_nodes * (1 - self.config.prune_ratio))
            if k_keep < num_nodes:
                _, keep_indices = torch.topk(roi_scores.squeeze(), k_keep)
                features = features[keep_indices]
                coords = coords[keep_indices]
                roi_scores = roi_scores[keep_indices]
                num_nodes = k_keep

        coords_np = coords.numpy()
        features_np = features.numpy()

        nbrs = NearestNeighbors(n_neighbors=self.config.k + 1, algorithm='ball_tree').fit(coords_np)
        spatial_dists, spatial_indices = nbrs.kneighbors(coords_np)

        spatial_indices = spatial_indices[:, 1:]
        spatial_dists = spatial_dists[:, 1:]

        source_nodes = np.repeat(np.arange(num_nodes), self.config.k)
        target_nodes = spatial_indices.flatten()
        spatial_weights = np.exp(- (spatial_dists.flatten() ** 2) / (2 * self.config.sigma ** 2))

        features_norm = F.normalize(features, p=2, dim=1).numpy()
        
        feat_sim_list = []
        for i in range(0, len(source_nodes), 10000):
            end = min(i + 10000, len(source_nodes))
            src_idx = source_nodes[i:sys
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch_geometric.data import Data, Dataset
from torch_geometric.utils import to_undirected, add_self_loops
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm
import warnings

# Suppress every warning for the whole process: no category or module filter
# is given, so this also hides deprecation and numerical warnings raised by
# the libraries imported above.
warnings.filterwarnings('ignore')

class DynamicEdgeConstructor:
    """Build a weighted k-nearest-neighbour graph from coordinates and features.

    Every node is linked to its ``k_neighbors`` spatially closest nodes. The
    weight of an edge blends feature similarity with spatial proximity::

        lambda_val * cosine_similarity + (1 - lambda_val) * exp(-d**2 / sigma**2)

    Args:
        k_neighbors: Number of spatial neighbours linked to each node.
        sigma: Bandwidth of the spatial kernel, in the units of the coordinates.
        lambda_val: Weight of the feature-similarity term in the blend.
        adaptive_pruning: When true, edges whose blended weight is not above
            ``PRUNE_THRESHOLD`` are removed.
    """

    # Blended weights at or below this value are dropped when pruning is on.
    PRUNE_THRESHOLD = 0.1
    # Weight assigned to the self-loop added to every node.
    SELF_LOOP_WEIGHT = 1.0

    def __init__(self, k_neighbors=10, sigma=256.0, lambda_val=0.6, adaptive_pruning=True):
        self.k = k_neighbors
        self.sigma = sigma
        self.lambda_val = lambda_val
        self.adaptive_pruning = adaptive_pruning

    def compute_spatial_weights(self, dists):
        """Return ``exp(-d**2 / sigma**2)`` for each distance in ``dists``.

        ``dists`` may be a NumPy array, a tensor or a plain sequence.
        ``torch.as_tensor`` is used so that tensor input is not re-wrapped
        (which would trigger a copy-construct warning).
        """
        dists = torch.as_tensor(dists)
        return torch.exp(-(dists ** 2) / (self.sigma ** 2))

    def compute_feature_sim(self, features, edge_index):
        """Return the cosine similarity between the two endpoints of each edge.

        Args:
            features: 2-D tensor with one feature row per node.
            edge_index: Index tensor whose first row holds source nodes and
                second row holds destination nodes.

        Returns:
            1-D tensor with one similarity value per edge.
        """
        # Normalise each node once, then gather; this gives the same rows as
        # normalising the gathered source/destination copies separately.
        unit_feats = F.normalize(features, p=2, dim=1)
        return F.cosine_similarity(unit_feats[edge_index[0]], unit_feats[edge_index[1]], dim=1)

    def build(self, coords, features):
        """Construct the undirected, self-looped, weighted k-NN graph.

        Args:
            coords: Array of node coordinates, one row per node.
            features: Tensor of node features, one row per node.

        Returns:
            ``(edge_index, edge_attr)`` where ``edge_index`` is a long tensor
            of shape ``(2, E)`` and ``edge_attr`` has shape ``(E, 1)``. When
            fewer than two nodes are given both are empty (``E == 0``).
        """
        num_nodes = coords.shape[0]

        # A node can have at most ``num_nodes - 1`` neighbours.
        k_eff = min(self.k, num_nodes - 1)
        if k_eff < 1:
            # Keep the (E, 1) layout of the regular return value.
            return torch.empty((2, 0), dtype=torch.long), torch.empty((0, 1))

        # Querying without an argument returns the neighbours of the fitted
        # points with each point excluded from its own neighbour list. This
        # stays correct when several nodes share identical coordinates,
        # whereas dropping "column 0" assumes the point itself sorts first.
        nbrs = NearestNeighbors(n_neighbors=k_eff, algorithm='ball_tree').fit(coords)
        distances, indices = nbrs.kneighbors()

        src_indices = np.repeat(np.arange(num_nodes), k_eff)
        dst_indices = indices.reshape(-1)
        spatial_dists = distances.reshape(-1)

        edge_index = torch.tensor(np.stack([src_indices, dst_indices]), dtype=torch.long)

        spatial_w = self.compute_spatial_weights(spatial_dists).float()
        feature_w = self.compute_feature_sim(features, edge_index)

        edge_attr = self.lambda_val * feature_w + (1 - self.lambda_val) * spatial_w

        if self.adaptive_pruning:
            keep = edge_attr > self.PRUNE_THRESHOLD
            edge_index = edge_index[:, keep]
            edge_attr = edge_attr[keep]

        # Both directions of a mutual-neighbour pair carry the same weight.
        # Averaging the duplicates keeps that weight; the default 'add'
        # reduction would double it relative to one-directional pairs.
        edge_index, edge_attr = to_undirected(
            edge_index, edge_attr, num_nodes=num_nodes, reduce='mean'
        )

        edge_index, edge_attr = add_self_loops(
            edge_index, edge_attr, fill_value=self.SELF_LOOP_WEIGHT, num_nodes=num_nodes
        )

        return edge_index, edge_attr.unsqueeze(1)

class WSIGraphBuilder:
    """Convert per-slide feature files into saved ``torch_geometric`` graphs.

    For every ``<slide_id>.pt`` file in ``raw_dir`` that has a row in the label
    table, the builder selects region-of-interest patches, builds edges with a
    ``DynamicEdgeConstructor`` and saves a ``Data`` object as
    ``<output_dir>/<slide_id>.pt``.

    Args:
        raw_dir: Directory holding the input ``.pt`` files. Each file is
            loaded as a mapping with the keys ``'features'``, ``'coords'`` and
            ``'roi_scores'``.
        output_dir: Directory the graphs are written to (created if missing).
        label_file: CSV or Excel table with ``slide_id`` and ``diagnosis``
            columns.
        config: Optional mapping of ``DynamicEdgeConstructor`` keyword
            arguments that override the defaults set in ``__init__``.
    """

    # Patches scoring above this value are treated as region of interest.
    ROI_THRESHOLD = 0.5
    # If fewer patches pass the threshold, the top-scoring ones are used instead.
    MIN_ROI_PATCHES = 50
    # Slides left with fewer selected patches than this are skipped.
    MIN_GRAPH_NODES = 5

    def __init__(self, raw_dir, output_dir, label_file, config=None):
        self.raw_dir = raw_dir
        self.output_dir = output_dir
        self.label_df = self._load_labels(label_file)
        self.label_map = self._create_label_map()

        constructor_kwargs = {'k_neighbors': 8, 'sigma': 128.0, 'lambda_val': 0.5}
        if config:
            constructor_kwargs.update(config)

        self.constructor = DynamicEdgeConstructor(**constructor_kwargs)
        os.makedirs(output_dir, exist_ok=True)

    def _load_labels(self, path):
        """Read the label table; ``.csv`` (any letter case) is CSV, anything else Excel."""
        if str(path).lower().endswith('.csv'):
            return pd.read_csv(path)
        return pd.read_excel(path)

    def _create_label_map(self):
        """Map each distinct ``diagnosis`` value to an index, in sorted order."""
        unique_labels = sorted(self.label_df['diagnosis'].unique())
        return {label: i for i, label in enumerate(unique_labels)}

    @staticmethod
    def _select_roi_mask(roi_scores, threshold=ROI_THRESHOLD, min_patches=MIN_ROI_PATCHES):
        """Return a boolean mask over a 1-D score tensor marking the patches to keep.

        Patches scoring above ``threshold`` are kept. When fewer than
        ``min_patches`` qualify, the ``min_patches`` highest-scoring patches
        (or all of them, if there are fewer) are kept instead.
        """
        mask = roi_scores > threshold
        if mask.sum() < min_patches:
            top = min(min_patches, roi_scores.numel())
            _, top_indices = torch.topk(roi_scores, top)
            mask = torch.zeros_like(mask)
            mask[top_indices] = True
        return mask

    def process(self):
        """Build and save one graph per labelled slide.

        Slides without a label row, and slides left with fewer than
        ``MIN_GRAPH_NODES`` selected patches, are skipped silently.

        Returns:
            The number of graphs written to ``output_dir``.
        """
        # Build the slide -> diagnosis lookup once instead of scanning the
        # table for every file. IDs are compared as strings because they are
        # matched against file names; the first row wins for a repeated ID.
        slide_labels = {}
        for sid, diagnosis in zip(self.label_df['slide_id'].astype(str),
                                  self.label_df['diagnosis']):
            slide_labels.setdefault(sid, diagnosis)

        # Sorted so that the processing order does not depend on the filesystem.
        feature_files = sorted(f for f in os.listdir(self.raw_dir) if f.endswith('.pt'))
        success_count = 0

        for f_name in tqdm(feature_files, desc="Building Graphs"):
            slide_id = os.path.splitext(f_name)[0]
            if slide_id not in slide_labels:
                continue

            label = torch.tensor([self.label_map[slide_labels[slide_id]]], dtype=torch.long)

            # Load onto the CPU: the coordinates go through NumPy/scikit-learn,
            # and files saved from a GPU must still load on a CPU-only host.
            raw_data = torch.load(os.path.join(self.raw_dir, f_name), map_location='cpu')
            features = raw_data['features']
            coords = raw_data['coords']
            if isinstance(coords, torch.Tensor):
                coords = coords.numpy()
            # Flatten so that (N,) and (N, 1) score layouts behave the same.
            roi_scores = torch.as_tensor(raw_data['roi_scores']).reshape(-1)

            roi_mask = self._select_roi_mask(
                roi_scores, self.ROI_THRESHOLD, self.MIN_ROI_PATCHES
            )

            filtered_features = features[roi_mask]
            filtered_coords = coords[roi_mask.cpu().numpy()]

            if len(filtered_coords) < self.MIN_GRAPH_NODES:
                continue

            edge_index, edge_attr = self.constructor.build(filtered_coords, filtered_features)

            data = Data(x=filtered_features,
                        edge_index=edge_index,
                        edge_attr=edge_attr,
                        y=label,
                        pos=torch.tensor(filtered_coords, dtype=torch.float),
                        slide_id=slide_id)

            torch.save(data, os.path.join(self.output_dir, f'{slide_id}.pt'))
            success_count += 1

        return success_count

class WSIGraphDataset(Dataset):
    """Dataset over the ``.pt`` graph files stored directly inside ``root``.

    ``root`` itself is used as the processed directory, and the dataset
    declares no raw files. The ``.pt`` files found in ``root`` at construction
    time, sorted by path, define the dataset order.

    Args:
        root: Directory containing one saved graph per ``.pt`` file.
        transform: Forwarded to the base ``Dataset``.
        pre_transform: Forwarded to the base ``Dataset``.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        self._processed_dir = root
        # Index the files before calling the base initialiser so that
        # ``processed_file_names`` and ``len`` are valid whenever they are
        # read, including from inside the base class.
        self.files = sorted(
            os.path.join(self._processed_dir, f)
            for f in os.listdir(self._processed_dir)
            if f.endswith('.pt')
        )
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """No raw inputs are declared for this dataset."""
        return []

    @property
    def processed_file_names(self):
        """Base names of the indexed graph files."""
        return [os.path.basename(f) for f in self.files]

    @property
    def processed_dir(self):
        """The directory passed as ``root``, used unchanged."""
        return self._processed_dir

    def len(self):
        """Number of indexed graph files."""
        return len(self.files)

    def get(self, idx):
        """Load and return the graph stored in the ``idx``-th file."""
        # NOTE(review): on PyTorch releases where ``torch.load`` defaults to
        # ``weights_only=True``, loading pickled graph objects may require
        # ``weights_only=False`` - confirm against the installed versions.
        return torch.load(self.files[idx])

def get_data_stats(dataset):
    """Summarise node and edge counts over an indexable dataset of graphs.

    Args:
        dataset: Object supporting ``len()`` and integer indexing whose items
            expose ``num_nodes`` and ``num_edges``.

    Returns:
        A dict with the keys ``'avg_nodes'``, ``'max_nodes'``, ``'avg_edges'``
        and ``'density'``, where density is the mean edge count divided by the
        squared mean node count. All values are zero for an empty dataset,
        and density is zero when the mean node count is zero.
    """
    total = len(dataset)
    if total == 0:
        # np.max raises on an empty sequence; report zeros instead.
        return {'avg_nodes': 0.0, 'max_nodes': 0, 'avg_edges': 0.0, 'density': 0.0}

    num_nodes = []
    num_edges = []
    for i in range(total):
        # Fetch each item once; indexing may load the graph from disk.
        data = dataset[i]
        num_nodes.append(data.num_nodes)
        num_edges.append(data.num_edges)

    avg_nodes = np.mean(num_nodes)
    avg_edges = np.mean(num_edges)

    return {
        'avg_nodes': avg_nodes,
        'max_nodes': np.max(num_nodes),
        'avg_edges': avg_edges,
        # Guard the division so graphs without nodes do not yield nan/inf.
        'density': avg_edges / (avg_nodes ** 2) if avg_nodes > 0 else 0.0,
    }

if __name__ == "__main__":
    import argparse

    # Command-line entry point: build one graph per labelled slide, then
    # report summary statistics for everything found in the output directory.
    parser = argparse.ArgumentParser(
        description="Build patch graphs from per-slide feature files.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--feature_dir', type=str, required=True,
                        help="directory containing the per-slide .pt feature files")
    parser.add_argument('--output_dir', type=str, default='./processed_graphs',
                        help="directory the graph files are written to")
    parser.add_argument('--label_file', type=str, required=True,
                        help="CSV or Excel table with slide_id and diagnosis columns")
    parser.add_argument('--k', type=int, default=12,
                        help="number of spatial neighbours per node")
    parser.add_argument('--sigma', type=float, default=200.0,
                        help="bandwidth of the spatial distance kernel")
    parser.add_argument('--lambda_val', type=float, default=0.6,
                        help="weight of feature similarity in the edge weight")

    args = parser.parse_args()

    config = {
        'k_neighbors': args.k,
        'sigma': args.sigma,
        'lambda_val': args.lambda_val
    }

    builder = WSIGraphBuilder(
        raw_dir=args.feature_dir,
        output_dir=args.output_dir,
        label_file=args.label_file,
        config=config
    )

    print("Starting Graph Construction...")
    count = builder.process()
    print(f"Successfully constructed {count} graphs.")

    dataset = WSIGraphDataset(args.output_dir)
    if len(dataset) > 0:
        stats = get_data_stats(dataset)
        print("Dataset Statistics:", stats)
    else:
        # Statistics over zero graphs are undefined; say so instead of failing.
        print("Dataset Statistics: no graphs found in", args.output_dir)
