import torch
import torch.nn as nn
import torch.nn.functional as F
from ConfigV1 import *

# ---------- helpers: padding + kernels ----------

def _pad_reflect(x, pad):
    """Reflect-pad the last two dimensions of ``x``.

    Args:
        x: tensor to pad.
        pad: 4-sequence ``(left, right, top, bottom)``; this is the order
            ``F.pad`` expects for the last two dimensions.

    Returns:
        The padded tensor, or ``x`` itself (same object, no ``F.pad`` call)
        when all four amounts are zero.
    """
    left, right, top, bottom = pad
    # Nothing to do: hand the input back untouched.
    if not (left or right or top or bottom):
        return x
    return F.pad(x, (left, right, top, bottom), mode='reflect')

def _gaussian_1d(win, sigma, device, dtype):
    """Build a normalised 1-D Gaussian window.

    Args:
        win: number of taps.
        sigma: standard deviation of the Gaussian, in taps.
        device: device the window is created on.
        dtype: dtype of the window.

    Returns:
        A 1-D tensor of length ``win`` centred on ``(win - 1) / 2`` whose
        entries sum to 1.
    """
    center = (win - 1) / 2.0
    offsets = torch.arange(win, device=device, dtype=dtype) - center
    # exp(-0.5 * (d / sigma)^2), written as a method chain.
    kernel = (offsets / sigma).pow(2).mul(-0.5).exp()
    return kernel / kernel.sum()

def _gaussian_2d(win=7, sigma=1.5, device='cpu', dtype=torch.float32):
    """Build a normalised ``win x win`` separable Gaussian kernel.

    The kernel is the outer product of the 1-D window from
    ``_gaussian_1d`` with itself, renormalised so its entries sum to 1.
    """
    taps = _gaussian_1d(win, sigma, device, dtype)
    # Column vector times row vector == outer product.
    kernel = taps.unsqueeze(1) * taps.unsqueeze(0)
    return kernel / kernel.sum()

# ---------- SSIM (reflect padding + Gaussian window) ----------

def ssim_loss(x, y, win=7, sigma=1.5, C1=0.01**2, C2=0.03**2):
    """Per-pixel ``1 - SSIM`` map for single-channel tensors in [0,1].

    Local statistics are computed with a Gaussian window; the inputs are
    reflect-padded and then convolved with ``padding=0`` so no zero seams
    appear at the borders.

    Args:
        x, y: tensors of shape (B,1,H,W) to compare.
        win: side length of the Gaussian window. Odd sizes are padded
            symmetrically; for even sizes the extra pixel of padding goes
            to the right/bottom so the output still matches the input size.
        sigma: standard deviation of the Gaussian window.
        C1, C2: SSIM stabilising constants.

    Returns:
        Tensor with the same spatial size as ``x`` holding
        ``1 - clamp(SSIM, 0, 1)``.

    Raises:
        AssertionError: if ``x`` does not have exactly one channel.
    """
    _, C, _, _ = x.shape  # also enforces a 4-D input
    assert C == 1, "ssim_loss expects (B,1,H,W)"

    # A valid conv with a `win`-wide kernel removes win - 1 pixels per axis,
    # so the two sides must add up to win - 1. For odd `win` both are
    # win // 2; for even `win` they differ by one.
    lo = (win - 1) // 2
    hi = win // 2
    pad = (lo, hi, lo, hi)

    w = _gaussian_2d(win=win, sigma=sigma, device=x.device, dtype=x.dtype)
    w = w.view(1, 1, win, win)

    def _blur(t_pad):
        # Gaussian-weighted local mean over an already padded tensor.
        return F.conv2d(t_pad, w, padding=0)

    # Reflect padding only copies pixels, so pad(x * x) == pad(x) * pad(x):
    # pad x and y once and form the second-order products afterwards.
    x_pad = _pad_reflect(x, pad)
    y_pad = _pad_reflect(y, pad)

    mu_x = _blur(x_pad)
    mu_y = _blur(y_pad)

    mu_x2 = mu_x * mu_x
    mu_y2 = mu_y * mu_y
    mu_xy = mu_x * mu_y

    # Local (co)variances: E[ab] - E[a]E[b].
    sigma_x = _blur(x_pad * x_pad) - mu_x2
    sigma_y = _blur(y_pad * y_pad) - mu_y2
    sigma_xy = _blur(x_pad * y_pad) - mu_xy

    ssim = ((2 * mu_xy + C1) * (2 * sigma_xy + C2)) / \
           ((mu_x2 + mu_y2 + C1) * (sigma_x + sigma_y + C2) + 1e-12)

    # NOTE: the clamp saturates the loss at 1 wherever SSIM is negative,
    # and clamped positions pass no gradient.
    return 1 - ssim.clamp(0, 1)

# ---------- Sobel (reflect padding) ----------

def sobel_grad(img):
    """Sobel gradient magnitude of ``img`` with reflect-padded borders.

    The 3x3 kernels are shaped (1,1,3,3), so ``img`` must have a single
    channel. The output keeps the spatial size of the input, and a small
    constant under the square root keeps the result strictly positive.
    """
    like_img = dict(dtype=img.dtype, device=img.device)
    horiz = torch.tensor([[-1, 0, 1],
                          [-2, 0, 2],
                          [-1, 0, 1]], **like_img).view(1, 1, 3, 3)
    vert = torch.tensor([[-1, -2, -1],
                         [0, 0, 0],
                         [1, 2, 1]], **like_img).view(1, 1, 3, 3)

    # Pad by one pixel so the valid 3x3 conv returns an input-sized map
    # without zero borders.
    padded = _pad_reflect(img, (1, 1, 1, 1))
    gx = F.conv2d(padded, horiz, padding=0)
    gy = F.conv2d(padded, vert, padding=0)
    return (gx.pow(2) + gy.pow(2) + 1e-12).sqrt()

# ---------- bounded intensity weighting + robust penalties ----------

def bounded_intensity_weight(tgt, w_min=1.0, w_max=6.0, k=8.0, mid=0.5):
    """Map target intensity to a weight bounded by ``w_min`` and ``w_max``.

    A sigmoid with slope ``k`` centred at ``mid`` blends between the two
    bounds: values far below ``mid`` approach ``w_min``, values far above
    approach ``w_max``, and ``tgt == mid`` gives the midpoint.
    """
    gate = torch.sigmoid(k * (tgt - mid))
    span = w_max - w_min
    return w_min + span * gate

def charbonnier(x, eps=1e-3):
    """Charbonnier penalty ``sqrt(x^2 + eps^2)``, a smooth robust stand-in for L1."""
    return (torch.square(x) + eps * eps).sqrt()

def huber(x, delta=1.0):
    """Element-wise Huber-style penalty, scaled by ``1 / delta``.

    For ``|x| <= delta`` the value is ``0.5 * x^2 / delta``; beyond that it
    grows linearly as ``|x| - 0.5 * delta``. This is the classic Huber loss
    divided by ``delta``.
    """
    magnitude = torch.abs(x)
    # Part of |x| that falls inside the quadratic zone...
    capped = magnitude.clamp(max=delta)
    # ...and whatever spills over into the linear zone.
    overflow = magnitude - capped
    return 0.5 * capped.pow(2) / delta + overflow


class EdgeAwareWeightedLoss(nn.Module):
    """Weighted sum of Huber, Sobel-gradient, SSIM and optional weighted-MSE terms.

    All terms share one per-pixel weight map built from optional
    caller-supplied spatial weights, a boundary ring around the target's
    ``> 0`` region, and an optional bounded intensity weight. Inputs are
    single-channel maps given as (B,1,H,W) or (B,H,W).

    Args:
        λ_l1: weight of the Huber term.
        λ_grad: weight of the Sobel gradient-magnitude L1 term.
        λ_ssim: weight of the ``1 - SSIM`` term.
        λ_wmse: weight of the exponential intensity-weighted MSE term;
            the term is skipped when this is 0.
        wmse_alpha: exponent scale in ``exp(wmse_alpha * target)``.
        wmse_base: constant added to the exponential intensity weight.
        wmse_max_weight: upper cap for the WMSE weight map, or ``None``
            to disable the cap.
        use_intensity_weight: multiply the shared weights by
            ``bounded_intensity_weight(targets)``.
        inten_w_min, inten_w_max, inten_k: parameters forwarded to
            ``bounded_intensity_weight``.
        boundary_band_px: radius ``r`` of the boundary ring; the
            morphological window is ``2 * r + 1`` pixels wide.
        boundary_boost: extra weight added on the boundary ring.
        ssim_win, ssim_sigma: Gaussian window size and sigma for SSIM.
        weight_normalize: divide weight maps by their (detached) mean so
            the weighted terms stay on a comparable scale.
        weight_eps: lower bound on the mean used for normalisation.
        huber_delta: threshold passed to ``huber`` (default 1.5, the value
            that was previously hard-coded).
    """

    def __init__(self,
                 λ_l1=1.0,
                 λ_grad=0.7,
                 λ_ssim=0.3,
                 λ_wmse=0.0,
                 wmse_alpha=8.0,
                 wmse_base=1.0,
                 wmse_max_weight=50.0,

                 use_intensity_weight=False,
                 inten_w_min=1.0,
                 inten_w_max=6.0,
                 inten_k=8.0,
                 boundary_band_px=3,
                 boundary_boost=3.0,
                 ssim_win=7, ssim_sigma=1.5,
                 weight_normalize=True,
                 weight_eps=1e-6,
                 huber_delta=1.5):
        super().__init__()
        self.λ_l1   = λ_l1
        self.λ_grad = λ_grad
        self.λ_ssim = λ_ssim

        self.λ_wmse      = λ_wmse
        self.wmse_alpha  = float(wmse_alpha)
        self.wmse_base   = float(wmse_base)
        self.wmse_max_w  = wmse_max_weight

        self.use_intensity_weight = use_intensity_weight
        self.inten_w_min = inten_w_min
        self.inten_w_max = inten_w_max
        self.inten_k = inten_k
        self.boundary_band_px = boundary_band_px
        self.boundary_boost = boundary_boost
        self.ssim_win = ssim_win
        self.ssim_sigma = ssim_sigma
        self.weight_normalize = weight_normalize
        self.weight_eps = weight_eps
        self.huber_delta = huber_delta

    def _edge_ring(self, tgt):
        """Return ``1 + boundary_boost`` on the boundary band of ``tgt > 0``, 1 elsewhere."""
        mask = (tgt > 0).float()
        r = self.boundary_band_px
        k = 2 * r + 1
        # Max-pooling a binary mask dilates it; eroding is dilating the
        # complement and complementing again.
        dil = F.max_pool2d(mask, k, stride=1, padding=r)
        ero = 1 - F.max_pool2d(1 - mask, k, stride=1, padding=r)
        ring = (dil - ero).clamp(min=0.0)
        return 1.0 + self.boundary_boost * ring

    def _shared_weights(self, targets, spatial_weights):
        """Compose the spatial/edge/intensity weight map shared by all terms."""
        tgt = targets.detach()  # no gradient flows through the weights

        W = 1.0
        if spatial_weights is not None:
            # Accept (H,W) or (B,H,W) and lift to a broadcastable 4-D map.
            if spatial_weights.dim() == 2:
                spatial_weights = spatial_weights.unsqueeze(0).unsqueeze(0)
            elif spatial_weights.dim() == 3:
                spatial_weights = spatial_weights.unsqueeze(1)
            W = W * spatial_weights

        # Boundary emphasis derived from the targets.
        W = W * self._edge_ring(tgt)

        if self.use_intensity_weight:
            W = W * bounded_intensity_weight(tgt,
                                             w_min=self.inten_w_min,
                                             w_max=self.inten_w_max,
                                             k=self.inten_k)

        if self.weight_normalize:
            mean_w = W.mean().clamp_min(self.weight_eps).detach()
            W = W / mean_w
        return W

    def _wmse_term(self, diff, targets, W):
        """Exponential intensity-weighted MSE that reuses the shared weights ``W``."""
        inten_w = self.wmse_base + torch.exp(self.wmse_alpha * targets.detach())
        W_wmse = W * inten_w

        # Optional cap to prevent exploding gradients.
        if self.wmse_max_w is not None:
            W_wmse = torch.clamp(W_wmse, max=self.wmse_max_w)

        # Keep this term on a comparable scale if normalization is enabled.
        if self.weight_normalize:
            mean_wmse = W_wmse.mean().clamp_min(self.weight_eps).detach()
            W_wmse = W_wmse / mean_wmse

        return (W_wmse * diff ** 2).mean()

    def forward(self, predictions, targets, spatial_weights=None):
        """Compute the scalar composite loss.

        Args:
            predictions: (B,1,H,W) or (B,H,W) tensor.
            targets: tensor with the same layout as ``predictions``.
            spatial_weights: optional (H,W), (B,H,W) or 4-D weight map
                multiplied into the shared weights; ``None`` (the default)
                means no extra spatial weighting.

        Returns:
            ``λ_l1*huber + λ_grad*grad + λ_ssim*ssim + λ_wmse*wmse``, each
            term being a weighted mean over all pixels.
        """
        # Promote (B,H,W) to (B,1,H,W).
        if predictions.dim() == 3:
            predictions = predictions.unsqueeze(1)
        if targets.dim() == 3:
            targets = targets.unsqueeze(1)

        diff = predictions - targets
        l1 = huber(diff, delta=self.huber_delta)
        lgrad = (sobel_grad(predictions) - sobel_grad(targets)).abs()
        lssim = ssim_loss(predictions, targets,
                          win=self.ssim_win, sigma=self.ssim_sigma)

        W = self._shared_weights(targets, spatial_weights)

        l1    = (W * l1).mean()
        lgrad = (W * lgrad).mean()
        lssim = (W * lssim).mean()

        l_wmse = 0.0
        if self.λ_wmse != 0.0:
            l_wmse = self._wmse_term(diff, targets, W)

        return self.λ_l1*l1 + self.λ_grad*lgrad + self.λ_ssim*lssim + self.λ_wmse*l_wmse