import numpy as np
from pathlib import Path
from torchmetrics import StructuralSimilarityIndexMeasure
import torch


# Module-wide default device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def create_loss_weights(dbzh, alpha=4.0, beta=0.1, device="cpu", normalize=True):
    """
    Create spatial weights for loss function when predicting radar reflectivity (dBZ).

    Pixels whose reflectivity is > 0 are treated as precipitation and receive the
    raw weight ``1 + alpha * dbz / max_dbz``; every other pixel receives ``beta``.
    A 3D input is first collapsed over its leading (time) axis: a pixel counts as
    precipitation if it is > 0 in any timestep, and its intensity is the maximum
    over time.

    Args:
        dbzh      : numpy array, torch tensor or array-like (e.g. nested list),
                    shape [time, x, y] or [x, y]
        alpha     : scaling factor for storm intensities (default: 4.0)
        beta      : weight for background (non-precipitation) pixels (default: 0.1)
        device    : 'cpu' or 'cuda'
        normalize : if True (default, the historical behaviour) the raw weights are
                    min-max rescaled to [0, 1]. Note that this maps the smallest raw
                    weight (the background, whenever background pixels exist) to 0,
                    and a field with uniform weights to all zeros. If False the raw
                    weights are returned: ``beta`` on background pixels and values
                    between 1 and ``1 + alpha`` on precipitation pixels.

    Returns:
        torch.Tensor (float32) of shape [x, y] located on ``device``.

    Raises:
        ValueError: if ``dbzh`` is neither 2D nor 3D.
    """
    # Convert to a float32 tensor on the requested device
    if isinstance(dbzh, torch.Tensor):
        dbzh = dbzh.to(device, dtype=torch.float32)
    else:
        # np.asarray lets plain array-likes (lists, tuples) through as well
        dbzh = torch.tensor(np.asarray(dbzh), dtype=torch.float32, device=device)

    # Case 1: 3D array (time, x, y) → collapse to 2D field
    if dbzh.ndim == 3:
        mask = (dbzh > 0).any(dim=0).float()             # 1 where precip exists in any timestep
        dbz_field = dbzh.max(dim=0).values               # take maximum intensity over time
    # Case 2: 2D array (x, y)
    elif dbzh.ndim == 2:
        mask = (dbzh > 0).float()
        dbz_field = dbzh
    else:
        raise ValueError("dbzh must be 2D [x,y] or 3D [time,x,y]")

    # Only positive reflectivity contributes to the weights (the mask zeroes out
    # everything else), so clamp negatives first. This keeps the peak >= 0, which
    # guarantees a strictly positive divisor and norm_dbz within [0, 1).
    precip_dbz = dbz_field.clamp(min=0)
    norm_dbz = precip_dbz / (precip_dbz.max() + 1e-6)

    # Hybrid weighting: background gets 'beta', precip scaled by intensity
    weights = mask * (1 + alpha * norm_dbz) + (1 - mask) * beta

    # Optional: normalize to [0,1] range
    if normalize:
        lowest = weights.min()
        weights = (weights - lowest) / (weights.max() - lowest + 1e-6)

    return weights



def min_max_normalization(data_type, data1, data2, min_val=0.0, max_val=1.0):
    """
    Linearly rescale ``data2`` using the minimum and maximum found in ``data1``.

    The minimum of ``data1`` is mapped to ``min_val`` and its maximum to
    ``max_val``; values of ``data2`` outside that range land outside
    [min_val, max_val].

    Args:
        data_type (string):  Kind of dataset being normalized: training, test, validation.
        data1 (numpy array): Reference data the min and max are taken from (i.e. the training data).
        data2 (numpy array): Data to rescale. For the training dataset data1 == data2.
        min_val (float): Value the reference minimum is mapped to.
        max_val (float): Value the reference maximum is mapped to.

    Returns:
        When data_type == 'training': tuple (normalized_data, data_min, data_max).
        Otherwise: normalized_data only.

    Raises:
        ValueError: if the min or max of data1 is NaN, or if they are equal.
    """
    lower = np.min(data1)
    upper = np.max(data1)

    # A NaN anywhere in the reference data poisons both extrema
    if any(np.isnan(bound) for bound in (lower, upper)):
        raise ValueError("Error: data_min or data_max is NaN. Please check the input data for invalid values.")

    # A constant reference field would lead to a division by zero below
    if upper == lower:
        raise ValueError("Error: data_max is equal to data_min. Cannot normalize with zero range.")

    unit_scaled = (data2 - lower) / (upper - lower)
    rescaled = unit_scaled * (max_val - min_val) + min_val

    if data_type != 'training':
        return rescaled
    return rescaled, lower, upper
        

def min_max_denormalization(normalized_data, data_min, data_max, min_val=0.0, max_val=1.0):
    """
    Undo a min-max normalization and map the data back to its original scale.

    Args:
        normalized_data (numpy array): Data previously scaled to [min_val, max_val].
        data_min (float): min value computed in the normalization process.
        data_max (float): max value computed in the normalization process.
        min_val (float): Lower bound that was used as normalization target.
        max_val (float): Upper bound that was used as normalization target.

    Returns:
        numpy array: denormalized_data.
    """
    # Step 1: bring the data from [min_val, max_val] back to the unit interval
    unit_scaled = (normalized_data - min_val) / (max_val - min_val)

    # Step 2: stretch the unit interval onto the original [data_min, data_max]
    original_span = data_max - data_min
    return unit_scaled * original_span + data_min


import numpy as np

def log_normalization(data_type, data1, data2, epsilon=1e-8):
    """
    Log normalization of a given dataset.

    Both inputs are shifted by ``epsilon`` and passed through the natural log;
    ``data2`` is then min-max scaled with the min and max of ``log(data1 + epsilon)``.

    Args:
        data_type (string):  The data type to normalize: training, test, validation.
        data1 (numpy array): Input data to extract min and max (usually training data).
        data2 (numpy array): The data to be normalized. In case of training dataset, data1 == data2.
        epsilon (float): Small constant added to avoid log(0).

    Returns:
        When data_type == 'training': tuple (normalized_data, data_min, data_max),
        where data_min / data_max are the extrema of the log-transformed data1.
        Otherwise: normalized_data only.

    Raises:
        ValueError: if any value of data1 or data2 is <= 0 after adding epsilon,
            if the log-transformed min or max is NaN, or if they are equal.
    """
    same_input = data2 is data1

    # Add epsilon to avoid issues with zero values
    data1_safe = data1 + epsilon
    data2_safe = data1_safe if same_input else data2 + epsilon

    # Validate BOTH inputs: a non-positive value in data2 would otherwise be
    # turned into NaN / -inf by np.log and silently leak into the result.
    if np.any(data1_safe <= 0) or (not same_input and np.any(data2_safe <= 0)):
        raise ValueError("Log normalization requires all input values to be positive after adding epsilon.")

    log_data1 = np.log(data1_safe)
    # Skip the second log when the caller normalizes the reference data itself
    log_data2 = log_data1 if same_input else np.log(data2_safe)

    data_min = np.min(log_data1)
    data_max = np.max(log_data1)

    if np.isnan(data_min) or np.isnan(data_max):
        raise ValueError("Error: data_min or data_max is NaN. Please check the input data for invalid values.")

    if data_max == data_min:
        raise ValueError("Error: data_max is equal to data_min. Cannot normalize with zero range.")

    normalized_data = (log_data2 - data_min) / (data_max - data_min)

    if data_type == 'training':
        return normalized_data, data_min, data_max
    else:
        return normalized_data


def log_denormalization(normalized_data, data_min, data_max, epsilon=1e-8):
    """
    Invert a log normalization and return the data on its original scale.

    Args:
        normalized_data (numpy array): Log-normalized data to convert back.
        data_min (float): Minimum log-transformed value from training data.
        data_max (float): Maximum log-transformed value from training data.
        epsilon (float): Small constant that was added during normalization.

    Returns:
        numpy array: denormalized_data (original scale).
    """
    # Map back into log space, exponentiate, then remove the epsilon shift
    log_span = data_max - data_min
    return np.exp(normalized_data * log_span + data_min) - epsilon



def count_parameters_by_layer(model):
    '''
    Count the trainable parameters of a model, per named parameter and in total.

    Only parameters with ``requires_grad`` set are included.

    Args:
        model: the initialized ML model (must provide ``named_parameters()``).

    Returns:
        layer_params (dictionary): parameter name -> number of elements.
        total_params (int): sum of all counts in layer_params.
    '''
    layer_params = {
        name: tensor.numel()
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    }
    total_params = sum(layer_params.values())

    return layer_params, total_params


def create_dir(dir):
    '''
    Make sure the given path exists, creating it as a directory if necessary.

    Missing parent directories are created as well. Nothing is done when the
    path already exists.

    Args:
        dir (string): directory name or path.
    '''
    target = Path(dir)

    # Already there: leave it untouched
    if target.exists():
        return

    target.mkdir(parents=True, exist_ok=True)


def filter_precipitation_timesteps(dbzh, variable, dbz_threshold=13.5, precipitation_percentage=50):
    """
    Keep only the timesteps in which enough of the field shows precipitation.

    A timestep is kept when the percentage of its points with
    ``dbzh >= dbz_threshold`` is at least ``precipitation_percentage``.

    Args:
        dbzh: reflectivity array indexed as [time, x, y].
        variable: array indexed along its first axis with the same timesteps as dbzh.
        dbz_threshold (float): reflectivity value from which a point counts as precipitating.
        precipitation_percentage (float): minimum percentage (0-100) of precipitating points.

    Returns:
        dbzh_filtered: dbzh restricted to the selected timesteps.
        variable_filtered: variable restricted to the selected timesteps.
        selected_timesteps (list of int): indices of the kept timesteps.
    """
    # Number of spatial points in a single timestep
    points_per_step = dbzh.shape[1] * dbzh.shape[2]

    def precip_coverage(field):
        # Percentage of points in one timestep at or above the threshold
        n_wet = np.sum(field >= dbz_threshold)
        return (n_wet / points_per_step) * 100

    # One timestep at a time, so only a single 2D boolean field is built at once
    selected_timesteps = [
        step
        for step in range(dbzh.shape[0])
        if precip_coverage(dbzh[step]) >= precipitation_percentage
    ]

    return (
        dbzh[selected_timesteps, :, :],
        variable[selected_timesteps, :, :],
        selected_timesteps,
    )





    
