import os
import math
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------------------
# Configuration Parameters
# -------------------------------------------
# Compute device: CUDA when PyTorch reports it as available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the scene to process, the trained weights and the results.
INPUT_FILE = ".../LC80410282019070LGN0020190311.tif" # change the input path
MODEL_PATH = "unet_attention_model_combine2.pth"  # Ensure the model file is at this path
OUTPUT_DIR = ".../output" # change the output path

# Per-channel bounds used to scale the 12 model inputs (must match the
# values used during training). One entry per input channel, in the
# order the channels are stacked.
LOWER_PERCENTILE = np.array([
    23,    # SZA
    100,   # DEM
    0.75,  # SVF
    0,     # cos_i_bi
    0,     # cos_i_sza
    0, 0, 0, 0, 0, 0, 0,  # TOA_B1 ... TOA_B7
])
UPPER_PERCENTILE = np.array([
    75,    # SZA
    3300,  # DEM
    1,     # SVF
    1,     # cos_i_bi
    2,     # cos_i_sza
    500, 550, 450, 400, 250, 30, 10,  # TOA_B1 ... TOA_B7
])

# Bounds used to map the 7 normalized output bands back to reflectance.
TARGET_LOWER = np.array([0] * 7)
TARGET_UPPER = np.array([1.0] * 7)


# -------------------------------------------
# 1. Model Definition
# -------------------------------------------
# Keep original model structure to correctly load weights
class PSPModule(nn.Module):
    """Pyramid pooling block that fuses multi-scale context into a feature map.

    Each stage average-pools the input to a ``size x size`` grid, projects it
    to ``in_channels // 4`` channels (1x1 conv + BatchNorm + ReLU) and is
    upsampled back to the input resolution. The input and all stage outputs
    are concatenated and reduced back to ``in_channels`` with a 1x1 conv.

    Args:
        in_channels: Number of channels of the incoming feature map.
        sizes: Pooled grid sizes, one pyramid stage per entry.

    Attribute names (``stages``, ``conv``) are kept as-is so that existing
    checkpoints keep loading.
    """

    def __init__(self, in_channels, sizes=(1, 2, 3, 6)):
        super(PSPModule, self).__init__()
        sizes = tuple(sizes)
        branch_channels = in_channels // 4
        self.stages = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(size, size)),
                nn.Conv2d(in_channels, branch_channels, kernel_size=1),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True)
            ) for size in sizes
        ])
        # Width of the concatenation built in forward(): the input plus one
        # branch per stage. For four stages and in_channels divisible by 4
        # this equals in_channels * 2, so the default layout is unchanged.
        fused_channels = in_channels + len(sizes) * branch_channels
        self.conv = nn.Conv2d(fused_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` enriched with pooled context."""
        target_size = x.shape[2:]
        pyramid = [
            F.interpolate(stage(x), size=target_size, mode='bilinear', align_corners=True)
            for stage in self.stages
        ]
        return self.conv(torch.cat([x] + pyramid, dim=1))


class SpectralAttention(nn.Module):
    """Gate a feature map with a sigmoid mask produced by a 1-D convolution.

    The spatial dimensions are flattened into one axis and a ``Conv1d``
    (kernel size 3, all channels in -> all channels out) is run along it; the
    sigmoid of the result multiplies the input element-wise.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super(SpectralAttention, self).__init__()
        self.conv1d = nn.Conv1d(in_channels, in_channels, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the attention mask; output shape equals input shape."""
        b, c, h, w = x.shape
        # reshape() instead of view(): view() raises on non-contiguous inputs
        # (e.g. transposed tensors), while reshape() copies only when it must.
        flat = x.reshape(b, c, h * w)
        mask = self.conv1d(flat).reshape(b, c, h, w)
        return x * self.sigmoid(mask)


class SpatialAttention(nn.Module):
    """Re-weight every pixel with a single-channel sigmoid attention map.

    The map is computed from the per-pixel channel mean and channel maximum,
    stacked into two channels and passed through one bias-free convolution.

    Args:
        kernel_size: Size of the square convolution kernel; the padding is
            ``(kernel_size - 1) // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by the attention map (broadcast over channels)."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        attention = self.sigmoid(self.conv(descriptor))
        return attention * x


class SEBlock(nn.Module):
    """Channel gating block: global average -> bottleneck -> per-channel scale.

    The feature map is averaged over its spatial dimensions, passed through
    two 1x1 convolutions (``channels -> channels // reduction -> channels``)
    with a ReLU in between, squashed by a sigmoid and used to scale the input
    channel by channel.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channels`` for the bottleneck width.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        squeezed_channels = channels // reduction
        self.fc1 = nn.Conv2d(channels, squeezed_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed_channels, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the per-channel gate; output shape equals input shape."""
        pooled = x.mean(dim=(2, 3), keepdim=True)
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(pooled))))
        return x * gate


class UNetWithAttention(nn.Module):
    """U-Net style encoder/decoder with attention blocks at every level.

    Layout:
        * encoder: one ``double_conv`` + ``SpectralAttention`` per entry of
          ``features``, each followed by 2x2 max pooling;
        * bottleneck: ``double_conv`` to twice the last feature width,
          ``PSPModule`` and an ``SEBlock``;
        * decoder: per level a transposed convolution (x2 upsampling),
          concatenation with the matching encoder output, ``double_conv`` and
          ``SpatialAttention``;
        * head: 1x1 convolution to ``out_channels``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        features: Channel width of each encoder level, shallowest first.

    Attribute names and their creation order are deliberately left untouched
    so that previously saved state dicts keep loading.
    """

    def double_conv(self, in_channels, out_channels):
        """Return two stacked 3x3 conv (no bias) -> BatchNorm -> ReLU stages."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    # `features` defaults to an immutable tuple (a list default would be one
    # object shared by every call); lists passed by callers still work.
    def __init__(self, in_channels=12, out_channels=7, features=(64, 128, 256, 512)):
        super(UNetWithAttention, self).__init__()
        features = list(features)

        self.encoder_blocks = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        current_channels = in_channels

        for feature in features:
            self.encoder_blocks.append(nn.Sequential(
                self.double_conv(current_channels, feature),
                SpectralAttention(feature)
            ))
            current_channels = feature

        bottleneck_channels = current_channels * 2
        self.bottleneck = nn.Sequential(
            self.double_conv(current_channels, bottleneck_channels),
            PSPModule(bottleneck_channels)
        )
        self.attention_block = SEBlock(bottleneck_channels)

        self.up_transpose = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        decoder_in_channels = bottleneck_channels
        for feature in reversed(features):
            self.up_transpose.append(
                nn.ConvTranspose2d(decoder_in_channels, feature, kernel_size=2, stride=2)
            )
            # The decoder conv sees the upsampled map concatenated with the
            # encoder skip of the same width, hence feature * 2 inputs.
            self.decoder_blocks.append(nn.Sequential(
                self.double_conv(feature * 2, feature),
                SpatialAttention()
            ))
            decoder_in_channels = feature

        self.final_conv = nn.Conv2d(decoder_in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Map a ``(B, in_channels, H, W)`` tensor to ``(B, out_channels, H, W)``."""
        skip_connections = []
        for enc in self.encoder_blocks:
            x = enc(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        x = self.attention_block(x)

        # Walk the decoder from the deepest level up, pairing every stage with
        # the encoder output of the same depth.
        levels = zip(self.up_transpose, self.decoder_blocks, reversed(skip_connections))
        for upsample, decode, skip in levels:
            x = upsample(x)
            if x.shape != skip.shape:
                # Shapes differ when a pooled dimension was odd; resize so the
                # concatenation below is valid.
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
            x = torch.cat((skip, x), dim=1)
            x = decode(x)

        return self.final_conv(x)


# -------------------------------------------
# 2. Helper Functions
# -------------------------------------------
def normalize_data(data, lower_p, upper_p):
    """Scale each channel of ``data`` linearly so lower_p -> 0 and upper_p -> 1.

    Args:
        data: Array of shape (C, H, W).
        lower_p: 1-D array with one lower bound per channel.
        upper_p: 1-D array with one upper bound per channel.
    Returns:
        Array of the same shape as ``data``; values outside the bounds are
        not clipped.
    """
    # Reshape the per-channel bounds to (C, 1, 1) so they broadcast over H, W.
    offset = lower_p[:, None, None]
    span = upper_p[:, None, None] - offset
    return (data - offset) / span


def denormalize_data(norm_data, lower_p, upper_p):
    """Undo the linear scaling: map 0 -> lower_p and 1 -> upper_p per channel.

    Args:
        norm_data: Tensor of shape (B, C, H, W).
        lower_p: NumPy array of per-channel lower bounds, 1-D (C,) or 2-D.
        upper_p: NumPy array of per-channel upper bounds, same layout as
            ``lower_p``.
    Returns:
        Tensor on the same device as ``norm_data``.
    """
    if lower_p.ndim == 1:
        # Promote (C,) bounds to (1, C) so they line up with the batch axis.
        lower_p, upper_p = lower_p[np.newaxis, :], upper_p[np.newaxis, :]

    def _to_broadcastable(bounds):
        # Append two singleton axes and move to the data's device as float32.
        return torch.tensor(bounds[:, :, None, None], dtype=torch.float32, device=norm_data.device)

    lower = _to_broadcastable(lower_p)
    upper = _to_broadcastable(upper_p)
    return norm_data * (upper - lower) + lower


def prepare_input_features(img, shadow_mask_path):
    """
    Build the 12-channel (un-normalized) model input from a multi-band scene.
    Args:
        img: Original image data (C, H, W); bands 8-19 are read
        shadow_mask_path: Path to the shadow mask file; when it does not
            exist a warning is printed and no shadow correction is applied
    Returns:
        input_img: Stacked features [SZA, DEM, SVF, cos_i_bi, cos_i_sza,
            TOA x 7] with shape (12, H, W); NaNs are replaced by -1
        invalid_mask: Boolean (H, W) array, True where DEM == 0
    """
    def _scaled(band_index):
        # Single band as float, with the 0.01 scale factor applied.
        return 0.01 * (img[band_index, :, :].astype(float))

    # Band layout of the input file:
    #   0-6   true reflectance (not used as input)
    #   7     QA
    #   8-14  TOA radiance
    #   15-19 solar angles and terrain layers
    toa = 0.01 * (img[8:15, :, :].astype(float))
    sza = 90 - _scaled(15)  # Solar Zenith Angle
    saa = _scaled(16)  # Solar Azimuth Angle
    dem = img[17, :, :].astype(float)
    slope = _scaled(18)
    aspect = _scaled(19)

    sza_rad = np.deg2rad(sza)
    slope_rad = np.deg2rad(slope)
    cos_sza = np.cos(sza_rad)
    cos_slope = np.cos(slope_rad)

    # Cosine of the local illumination angle, floored at zero.
    # NOTE(review): `aspect` is subtracted from a radian angle without being
    # passed through deg2rad, unlike sza/saa/slope. This looks like a unit
    # mismatch, but it is kept as-is because the features must match what the
    # model was trained on -- confirm against the training pipeline.
    cosi = (cos_sza * cos_slope +
            np.sin(sza_rad) * np.sin(slope_rad) *
            np.cos(np.deg2rad(saa) - aspect))
    cosi = np.where(cosi < 0, 0, cosi)

    # Binary illumination flag and illumination relative to a flat surface.
    cosi_bi = (cosi > 0).astype(float)
    # Leave 0 wherever cos(SZA) is exactly 0 to avoid dividing by zero.
    cosi_sza = np.divide(cosi, cos_sza, out=np.zeros_like(cosi), where=cos_sza != 0)

    svf = (1 + cos_slope) / 2

    # Shadowed pixels get no direct illumination.
    if not os.path.exists(shadow_mask_path):
        print(f"Warning: Shadow mask not found at {shadow_mask_path}, skipping shadow correction.")
    else:
        with rasterio.open(shadow_mask_path) as src:
            # Assuming a value of 1 marks shadow
            in_shadow = src.read(1) == 1
        cosi_bi[in_shadow] = 0
        cosi_sza[in_shadow] = 0

    # Pixels with DEM == 0 are invalid: blank every feature except DEM itself.
    invalid_mask = (dem == 0)
    toa[:, invalid_mask] = np.nan
    for layer in (sza, cosi_bi, cosi_sza, svf):
        layer[invalid_mask] = np.nan

    # Channel order: [SZA, DEM, SVF, Cosi_bi, Cosi_sza, TOA(7 bands)] -> 12 channels
    single_band_layers = np.stack([sza, dem, svf, cosi_bi, cosi_sza])
    input_features = np.concatenate([single_band_layers, toa], axis=0)

    # NaNs become -1 (consistent with training preprocessing)
    input_features = np.nan_to_num(input_features, nan=-1)

    return input_features, invalid_mask


# -------------------------------------------
# 3. Main Execution
# -------------------------------------------
def main():
    """Run inference on ``INPUT_FILE`` and write the estimated reflectance.

    Steps: load the weights from ``MODEL_PATH``, read every band of
    ``INPUT_FILE``, build and normalize the 12 input features, run the
    network on the whole scene at once, denormalize the 7 output bands, zero
    the invalid pixels and store the result as an LZW-compressed int16
    GeoTIFF (scale factor 10000) named
    ``<input name>_estimated_reflectance.tif`` inside ``OUTPUT_DIR``.

    A shadow mask named ``shadow_<input name>.tif`` is looked up next to the
    input file.

    Raises:
        FileNotFoundError: If the model file or the input file is missing.
    """
    print(f"Running on device: {DEVICE}")

    # 1. Load Model
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    print("Loading model...")
    model = UNetWithAttention(in_channels=12, out_channels=7).to(DEVICE)
    # weights_only=True restricts unpickling to tensors/plain containers
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    # 2. Prepare Input File Paths
    if not os.path.exists(INPUT_FILE):
        raise FileNotFoundError(f"Input file not found: {INPUT_FILE}")

    input_dir = os.path.dirname(INPUT_FILE)
    base_name = os.path.splitext(os.path.basename(INPUT_FILE))[0]
    shadow_file = os.path.join(input_dir, f"shadow_{base_name}.tif")

    print(f"Processing: {base_name}")

    # 3. Read and Preprocess Data
    with rasterio.open(INPUT_FILE) as src:
        img_data = src.read()
        profile = src.profile

    input_arr, invalid_mask = prepare_input_features(img_data, shadow_file)
    norm_input = normalize_data(input_arr, LOWER_PERCENTILE, UPPER_PERCENTILE)
    input_tensor = torch.from_numpy(norm_input).float().unsqueeze(0).to(DEVICE)

    # 4. Model Inference
    print("Running inference...")
    with torch.no_grad():
        # Autocast is keyed on the device the model actually lives on.
        with torch.amp.autocast(DEVICE.type):
            output_norm = model(input_tensor)

        # Leave the autocast scope before denormalizing so it runs in float32.
        output_norm = output_norm.to(dtype=torch.float32)
        estimated_ref = denormalize_data(output_norm, TARGET_LOWER, TARGET_UPPER)

    estimated_ref = estimated_ref.cpu().squeeze(0).numpy()

    # 5. Post-processing and Saving
    # Apply invalid value mask
    estimated_ref[:, invalid_mask] = 0

    # Convert to int16 (scaling factor: 10000). The network output is not
    # bounded, so round to the nearest integer and clip to the int16 range
    # first; a bare astype() would truncate and silently wrap out-of-range
    # values around.
    int16_range = np.iinfo(np.int16)
    scaled_ref = np.rint(estimated_ref * 10000)
    estimated_ref_int = np.clip(scaled_ref, int16_range.min, int16_range.max).astype(np.int16)

    # Prepare output path
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    output_path = os.path.join(OUTPUT_DIR, f"{base_name}_estimated_reflectance.tif")

    # Reuse the input's georeferencing, overriding only what differs
    profile.update(
        dtype='int16',
        count=7,  # Save only the 7 estimated bands
        compress='lzw'
    )

    print(f"Saving result to: {output_path}")
    with rasterio.open(output_path, 'w', **profile) as dst:
        dst.write(estimated_ref_int)
        # Label each band (rasterio band indices are 1-based)
        for band_index in range(1, 8):
            dst.set_band_description(band_index, f'Estimated_SR_B{band_index}')

    print("Done.")


# Run the inference pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()