import torch
import torch.nn as nn
from ConfigV1 import *
# NOTE(review): the triple-quoted string below is a disabled, earlier version
# of WeightedMSELoss kept only for reference. Because it follows the imports it
# is not a module docstring, just a bare string expression that is evaluated
# and discarded, so none of the code inside it runs. That version overrode
# __call__ instead of forward, which would bypass nn.Module's hook machinery.
# The active implementation is the WeightedMSELoss class defined after it.
# Consider deleting this string (version control keeps the history).
'''
class WeightedMSELoss(nn.Module):
    def __init__(self):
        super(WeightedMSELoss, self).__init__()
    
    def __call__(self, predictions, targets, weights):
        """
        Compute the weighted MSE loss.
        
        Args:
            predictions (torch.Tensor): Predicted values.
            targets (torch.Tensor): Ground truth values.
            weights (torch.Tensor): Weights for each sample/element.
        
        Returns:
            torch.Tensor: Weighted MSE loss.
        """
        # Ensure weights are broadcastable to predictions and targets
        squared_diff = (predictions - targets) ** 2
        weighted_squared_diff = weights * squared_diff
        return torch.mean(weighted_squared_diff)
'''

class WeightedMSELoss(nn.Module):
    """Mean squared error weighted by a spatial map and by target intensity.

    Each element's squared error ``(predictions - targets) ** 2`` is multiplied
    by ``spatial_weights * (base + exp(alpha * targets))``. The combined weight
    is optionally capped at ``max_weight``, and the weighted errors are
    averaged over all elements.
    """

    def __init__(self, alpha=8, base=1.0, max_weight=50, *, verbose=True):
        """
        alpha: controls how strongly weights grow with target intensity
            (the intensity weight is ``base + exp(alpha * targets)``).
        base: additive offset of the intensity weight. This is NOT the minimum
            weight: at a target of 0 the intensity weight is ``base + 1``, and
            it only approaches ``base`` for strongly negative targets.
        max_weight: optional cap on the combined (spatial * intensity) weight,
            to avoid exploding gradients; ``None`` disables the cap.
        verbose: when True (the default, matching previous behavior), print
            the configuration on construction.
        """
        super().__init__()
        self.alpha = alpha
        self.base = base
        self.max_weight = max_weight
        if verbose:
            print(f"alpha={self.alpha}, base={self.base}, max_weight={self.max_weight}")

    def forward(self, predictions, targets, spatial_weights):
        """Return the scalar weighted MSE.

        Args:
            predictions: predicted values.
            targets: ground-truth values; they also drive the intensity weight.
            spatial_weights: per-element weights (tensor or scalar) that
                broadcast against ``targets``. Elements whose spatial weight is
                exactly 0 contribute nothing to the loss.

        Returns:
            0-dim tensor: mean of the weighted squared errors.
        """
        # Pixelwise squared error
        squared_diff = (predictions - targets) ** 2

        # Intensity-based weighting (same shape as targets)
        intensity_weights = self.base + torch.exp(self.alpha * targets)

        # Combine spatial + intensity
        weights = spatial_weights * intensity_weights

        # exp() overflows to inf for large alpha * targets (above ~88.7 in
        # float32), and 0 * inf is NaN, which clamp() does not remove. A zero
        # spatial weight means "ignore this element", so force those weights
        # to 0 rather than letting NaN poison the mean and the gradients.
        zero_spatial = torch.as_tensor(spatial_weights, device=weights.device) == 0
        weights = weights.masked_fill(zero_spatial, 0.0)

        if self.max_weight is not None:
            weights = torch.clamp(weights, max=self.max_weight)

        return (weights * squared_diff).mean()
