import torch
import torch.nn as nn
import torch.nn.functional as F

# ---- helpers for safe alignment ----
def center_crop_to(x, target_hw):
    """Center-crop the last two dims of a 4-D tensor to at most ``target_hw``.

    A spatial dim larger than its target is cropped symmetrically, starting
    at offset ``(size - target) // 2``. A dim that is already smaller than or
    equal to its target is returned uncropped, so the result may be smaller
    than ``target_hw``; this function never pads.

    Args:
        x: tensor unpacked as ``(N, C, H, W)``.
        target_hw: ``(H_target, W_target)``; a ``torch.Size`` works too.

    Returns:
        A view of ``x`` with spatial size ``(min(H, H_target), min(W, W_target))``.
    """
    _, _, height, width = x.shape
    out_h, out_w = target_hw
    # Offsets are clamped at 0 so undersized dims are left untouched.
    top = (height - out_h) // 2 if height > out_h else 0
    left = (width - out_w) // 2 if width > out_w else 0
    return x[:, :, top:top + out_h, left:left + out_w]

def pad_or_crop_to(x, target_hw, mode='reflect'):
    """Center-crop and/or pad a 4-D tensor so its last two dims equal ``target_hw``.

    Each spatial dim larger than its target is center-cropped. This uses the
    same rule as ``center_crop_to``: offset ``(size - target) // 2``. Each dim
    smaller than its target is padded on both sides. When the pad amount is
    odd, the extra element goes to the bottom/right.

    Args:
        x: tensor unpacked as ``(N, C, H, W)``.
        target_hw: ``(H_target, W_target)``.
        mode: ``F.pad`` mode (default ``'reflect'``, the previous behavior).
            Reflect padding requires every pad amount to be smaller than the
            dimension being padded. When that does not hold, this falls back
            to ``'replicate'`` instead of raising.

    Returns:
        Tensor of shape ``(N, C, H_target, W_target)``.
    """
    _, _, H, W = x.shape
    Ht, Wt = target_hw

    # Crop: shrink every dim that exceeds its target, keeping the center.
    Hc, Wc = min(H, Ht), min(W, Wt)
    dh, dw = (H - Hc) // 2, (W - Wc) // 2
    x = x[:, :, dh:dh + Hc, dw:dw + Wc]

    # Pad: grow every dim that falls short of its target.
    pad_h, pad_w = Ht - Hc, Wt - Wc
    if pad_h <= 0 and pad_w <= 0:
        return x
    top = pad_h // 2
    bottom = pad_h - top      # bottom >= top, so it is the binding constraint
    left = pad_w // 2
    right = pad_w - left      # right >= left, likewise

    # F.pad(mode='reflect') errors when pad >= input size; degrade gracefully.
    if mode == 'reflect' and (bottom >= Hc or right >= Wc):
        mode = 'replicate'
    return F.pad(x, (left, right, top, bottom), mode=mode)

# SE block
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel is global-average-pooled. The resulting channel vector goes
    through a bias-free bottleneck MLP
    (``in_channels -> max(1, in_channels // reduction) -> in_channels`` with a
    ReLU in between). The input is then rescaled channel-wise by the sigmoid
    of the MLP output.

    Expects a 4-D ``(N, C, H, W)`` input with ``C == in_channels``; the output
    has the same shape.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = max(1, in_channels // reduction)  # never collapse to 0 units
        # Submodule names/order kept stable: they define state_dict keys and
        # the order in which parameters are randomly initialized.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, in_channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.global_avg_pool(x).view(batch, channels)
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(squeezed))))
        return x * gate.view(batch, channels, 1, 1)

# Residual block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm stages with SE gating and an additive shortcut.

    Main path:  conv3x3 -> BN -> LeakyReLU -> conv3x3 -> BN -> SEBlock.
    Shortcut:   identity when channel counts match, otherwise a 1x1 conv.
    The two paths are summed and passed through a final LeakyReLU. Both 3x3
    convs use ``padding=1``, so H and W are unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Creation order is preserved: it fixes state_dict keys and the RNG
        # order used for parameter initialization.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU()  # stateless, so reused for both activations
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.se = SEBlock(out_channels)
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            # 1x1 projection so the shortcut matches the main path's channels.
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        shortcut = self.skip(x)
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.se(self.bn2(self.conv2(h)))
        return self.relu(h + shortcut)

# ResUNet (no output_padding; safe size alignment)
class ResUNet(nn.Module):
    """Residual U-Net (three down/up levels) built from SE residual blocks.

    Encoder: ``ResidualBlock`` stages with 16 -> 32 -> 64 -> 128 channels,
    separated by 2x2 max-pooling. Decoder: three stride-2 transposed convs.
    Each is followed by concatenation with the matching encoder map and a
    ``ResidualBlock``. A final 1x1 conv maps to ``out_channels``.

    Arbitrary spatial sizes are supported. ``F.max_pool2d`` floors odd sizes,
    which always discards the trailing (bottom/right) row/column, and each
    stride-2 transposed conv exactly doubles H/W. Decoder maps therefore line
    up with the *top-left* region of each skip map. Skips are cropped from the
    top-left, and the final output is padded back to the input size at the
    bottom/right. Center-cropping instead would shift features by
    ``(size difference) // 2`` pixels whenever H or W is not 0 or 1 mod 8.
    """

    def __init__(self, in_channels=19, out_channels=1):
        super().__init__()
        # Encoder
        self.enc1 = ResidualBlock(in_channels, 16)
        self.enc2 = ResidualBlock(16, 32)
        self.enc3 = ResidualBlock(32, 64)
        self.enc4 = ResidualBlock(64, 128)

        # Decoder: stride-2 transposed convs with no output_padding, so each
        # exactly doubles H/W; size mismatches are resolved in forward().
        self.upconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec4 = ResidualBlock(128, 64)

        self.upconv3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec3 = ResidualBlock(64, 32)

        self.upconv2 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2, padding=0)
        self.dec2 = ResidualBlock(32, 16)

        self.final_conv = nn.Conv2d(16, out_channels, kernel_size=1)

    @staticmethod
    def _crop_top_left(skip, hw):
        """Crop ``skip`` to spatial size ``hw``, keeping its top-left corner."""
        h, w = hw
        return skip[:, :, :h, :w]

    def _decode(self, upconv, block, below, skip):
        """Upsample ``below``, concatenate the aligned ``skip``, then refine with ``block``."""
        up = upconv(below)
        aligned = self._crop_top_left(skip, up.shape[-2:])
        return block(torch.cat((up, aligned), dim=1))

    @staticmethod
    def _pad_bottom_right(x, hw):
        """Pad ``x`` at the bottom/right so its last two dims equal ``hw``.

        Reflect padding is used when PyTorch permits it (every pad amount is
        smaller than the padded dim). Otherwise replicate padding is used.
        """
        H, W = hw
        x = x[:, :, :H, :W]  # decoder output never exceeds the input; guard anyway
        pad_h = H - x.shape[-2]
        pad_w = W - x.shape[-1]
        if pad_h <= 0 and pad_w <= 0:
            return x
        mode = 'reflect' if (pad_h < x.shape[-2] and pad_w < x.shape[-1]) else 'replicate'
        return F.pad(x, (0, pad_w, 0, pad_h), mode=mode)

    def forward(self, x, epoch=None):  # `epoch` is unused; kept to match the trainer's call signature
        _, _, H0, W0 = x.shape

        # Encoder (sizes per spatial dim; // is floor division from max-pooling)
        enc1 = self.enc1(x)                           # H
        enc2 = self.enc2(F.max_pool2d(enc1, 2))       # H//2
        enc3 = self.enc3(F.max_pool2d(enc2, 2))       # H//4
        enc4 = self.enc4(F.max_pool2d(enc3, 2))       # H//8

        # Decoder; each stage output is 2x its input, always <= the skip size.
        dec4 = self._decode(self.upconv4, self.dec4, enc4, enc3)   # 2*(H//8)
        dec3 = self._decode(self.upconv3, self.dec3, dec4, enc2)   # 4*(H//8)
        dec2 = self._decode(self.upconv2, self.dec2, dec3, enc1)   # 8*(H//8)

        out = self.final_conv(dec2)
        # Restore the H0 x W0 input size; at most 7 rows/cols are missing.
        # Output is returned raw (no clamping / activation).
        return self._pad_bottom_right(out, (H0, W0))

# ---- quick sanity check for your shape ----
if __name__ == "__main__":
    # Shape sanity check: the output must have the input's spatial size,
    # including sizes that are not multiples of 8 (the network's pooling
    # factor) and non-square inputs.
    net = ResUNet(in_channels=19, out_channels=1)
    net.eval()
    with torch.no_grad():
        for h, w in ((137, 137), (140, 140), (137, 131)):
            x = torch.randn(16, 19, h, w)  # (batch=16, channels=19, H, W)
            y = net(x)
            print("Input:", tuple(x.shape), "Output:", tuple(y.shape))
            assert tuple(y.shape) == (16, 1, h, w), tuple(y.shape)