import torch
import xarray as xr
import numpy as np
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchmetrics import StructuralSimilarityIndexMeasure
from skimage.metrics import structural_similarity as ssim


from Utils import *
from MultichannelDataset import MultichannelDataset
#from UNet import UNet
from ResUNet import ResUNet

from Trainer import train_model
from ConfigV1 import *
import os
import gc


print('@@@@@ SUMMARY OF THE ARCHITECTURE @@@@@ \n', flush=True)

# Echo the run configuration. These names are not defined in this file; they
# come from the star imports at the top (presumably ConfigV1). Each label is
# left-justified to 15 characters so every "-->" arrow lands in the same column.
for _label, _setting in (
    ('Model structure', model),
    ('Optimizer',       optimizer),
    ('Loss Function',   criterion),
    ('Batch size',      batch_size),
    ('Initial lr',      lr),
    ('Epochs',          num_epochs),
    ('Model name',      model_name),
    ('Channels',        in_channels),
):
    print(f'{_label.ljust(15)} --> {_setting}', flush=True)

print('######################################################################################', flush=True)


# OPEN DATASETS
print('\n@@@@@ DATA @@@@@ \n', flush=True)

print('Loading datasets...', flush=True)
print('######################################################################################', flush=True)
# `aladin` and `radar` hold file paths from the config; they are rebound here to
# the opened, time-chunked (dask-backed) datasets.
aladin = xr.open_dataset(aladin, chunks={"time": 100})
radar  = xr.open_dataset(radar, chunks={"time": 100})


def _keep_timesteps(ds, component, wanted):
    """Return ``ds`` restricted to timesteps whose ``time.dt.<component>`` is in ``wanted``.

    Args:
        ds: dataset with a ``time`` coordinate.
        component: datetime accessor attribute to test, e.g. ``'year'`` or ``'hour'``.
        wanted: collection of accepted values for that component.

    Uses positional boolean indexing along ``time``. ``Dataset.where(..., drop=True)``
    would instead apply the mask to every data variable: integer variables get
    promoted to float, and variables without a ``time`` dimension get one
    broadcast onto them.
    """
    keep = getattr(ds['time'].dt, component).isin(wanted).values
    return ds.isel(time=keep)


if years_selection:
    print(f'\n@Getting years {desired_years}',flush=True)
    aladin = _keep_timesteps(aladin, 'year', desired_years)
    radar  = _keep_timesteps(radar, 'year', desired_years)

if hours_selection:
    print(f'\n@Getting times {desired_hours}',flush=True)
    aladin = _keep_timesteps(aladin, 'hour', desired_hours)
    radar  = _keep_timesteps(radar, 'hour', desired_hours)



# GET VARIABLES
print('\nGetting variables...', flush=True)
print('######################################################################################', flush=True)

# Input fields, keyed by variable name; pressure-level variables get one entry
# per level, keyed f'{name}{level_index}'.
variables = {}


def _load_pressure_levels(name, n_levels, clip_negative=False, show_available=False):
    """Store the first ``n_levels`` levels of ALADIN variable ``name`` in ``variables``.

    Level ``j`` is stored under ``f'{name}{j}'`` as ``aladin[name][:, j, :, :]``.
    The level axis is taken to be axis 1, in the same order as
    ``aladin.isobaricInhPa``; the printed level list relies on that too
    (TODO confirm for every variable).

    Indexing the DataArray *before* ``.values`` computes only that level from the
    dask-chunked dataset. Calling ``.values`` first would materialise the full 4-D
    array once per level and keep it alive through the stored numpy view.

    Args:
        name: ALADIN variable name.
        n_levels: number of levels to load; 0 (the config's "disabled" value) loads nothing.
        clip_negative: set negative values to 0.0.
        show_available: also print every available pressure level.
    """
    if n_levels == 0:
        return
    if show_available:
        print(f'Available P levels    --> {aladin.isobaricInhPa.values} hPa')
    print(f'Selected levels for {name} --> {aladin.isobaricInhPa[0:n_levels].values} hPa')
    for j in range(n_levels):
        level = aladin[name][:, j, :, :].values
        if clip_negative:
            level[level < 0] = 0.0
        variables[f'{name}{j}'] = level


for i, var in enumerate(var_names):
    print(f'{i} - Loading {var}')

    # The number of levels per variable (qLevels, tLevels, ...) is set in the config file.
    if var == 'q':
        _load_pressure_levels(var, qLevels, clip_negative=True, show_available=True)
    elif var == 't':
        _load_pressure_levels(var, tLevels)
    elif var == 'v':
        _load_pressure_levels(var, vLevels)
    elif var == 'u':
        _load_pressure_levels(var, uLevels)
    elif var == 'r':
        _load_pressure_levels(var, rLevels)
    elif var == 'r2m':
        # Single field; negative values are clipped to 0.
        variables[f'{var}'] = aladin[var].values
        variables[f'{var}'][variables[f'{var}'] < 0] = 0.0
    else:
        variables[f'{var}'] = aladin[var].values

print("\n@ Variables:", flush = True)
for key in variables.keys():
    print(key, flush = True)


# Pull radar reflectivity and its timestamps out as plain numpy arrays, then
# release both xarray datasets (all ALADIN inputs now live in `variables`).
dbzh = radar['DBZH'].values
time = radar['time'].values

del aladin, radar
gc.collect()

# Missing radar data leaves holes along the time dimension. Those timesteps
# became 0 during the daily average and have to be removed (see the
# precipitation-timestep selection further below).

print('\nConvert NaN to 0, create weights for loss and select precipitation timesteps...', flush=True)
print('######################################################################################', flush=True)

print('\n@ Initial shapes')
for name, field in variables.items():
    print(f'{name:<8} shape --> {str(field.shape):>20}', flush = True)

print(f'dbzh     shape --> {str(dbzh.shape).rjust(20)}', flush=True)

# NaN reflectivity is treated as "no echo"; copy=False lets numpy overwrite in place.
dbzh = np.nan_to_num(dbzh, copy=False, nan=0.0)



print('\n@ Create weights for loss function')

# Only the spatial shape of the radar grid is needed to build the weight map,
# so take it directly from DBZH (time, x, y) instead of reducing the whole array.
x_dim, y_dim = dbzh.shape[1:]
center_x, center_y = x_dim // 2, y_dim // 2  # centre of the circular weight map
sigma = 45  # Gaussian standard deviation, in grid cells, controlling the decay

# Euclidean distance (in grid cells) of every pixel from the centre
xx, yy = np.meshgrid(np.arange(x_dim), np.arange(y_dim), indexing="ij")
distances = np.sqrt((xx - center_x) ** 2 + (yy - center_y) ** 2)

# Gaussian decay with distance from the centre
gaussian_weights = np.exp(-distances**2 / (2 * sigma**2))

# Inside `radius` (from the config) the weight is a flat 0.9, giving a sharp
# boundary; outside it the Gaussian decay applies.
print(f'RADIUS is {radius}', flush=True)
final_weights = np.where(distances < radius, 0.9, gaussian_weights)
# NOTE(review): final_weights is float64, so this tensor is float64 too; confirm
# that is what the weighted loss in Trainer expects.
LOSSweights = torch.tensor(final_weights, device=device)


# Plot the weight map and release the figure afterwards
fig = plt.figure(figsize=(6, 6), dpi=200)
plt.imshow(final_weights, origin='lower', cmap='tab20', vmin=0)
plt.colorbar(label='Weight')
plt.title("Weight Map with Gaussian Decay")
plt.xlabel("x")
plt.ylabel("y")
plt.savefig(f'{fig_dir}/LOSSweights.png', dpi=200, bbox_inches='tight')
plt.close(fig)
print(f'Loss weight image saved in:\n{fig_dir}')



########################################################################################################################

# Keep only timesteps that pass the precipitation criterion (dbz_threshold /
# precipitation_percentage from the config). Each call returns the filtered
# DBZH, the filtered variable and the selected timestep indices. Only the last
# call's DBZH and indices are kept; they are assumed identical for every
# variable (TODO confirm in Utils.filter_precipitation_timesteps).
if not variables:
    raise ValueError('No input variables were loaded; check var_names and the *Levels settings in the config.')

for key in variables.keys():
    dbzh1, variables[key], selected_timesteps = filter_precipitation_timesteps(dbzh, variables[key], dbz_threshold, precipitation_percentage)

time = time[selected_timesteps]
dbzh = dbzh1

print(f'\n@ Final shapes \n  after selecting dbzh values with a percentage of covering greater than {precipitation_percentage}%')
for key in variables.keys():
    print(f'{key.ljust(8)} shape --> {str(variables[key].shape).rjust(20)}', flush = True)
print(f'dbzh     shape --> {str(dbzh.shape).rjust(20)}', flush=True)
print(f'time     shape --> {str(time.shape).rjust(20)}', flush=True)

# SPLIT IN TRAIN AND TEST
# Report the actual `shuffle` flag from the config instead of a hard-coded "True".
print(f'\nSplit the data in train, test and val (shuffle = {shuffle}) ...', flush=True)
print('######################################################################################', flush=True)

# Split fractions (must sum to 1)
train_size = 0.8
val_size   = 0.1
test_size  = 0.1

# Stack the variables along a trailing channel axis; channel c is the c-th key of `variables`.
inputs = np.stack([variables[var] for var in variables.keys()], axis=-1)
print('\n@ Shape of stacked variables array', flush=True)
print(inputs.shape, flush=True)


# First split: training (train_size) vs. the remainder (val + test)
inputs_train, inputs_remain, dbzh_train, dbzh_remain, time_train, time_remain = train_test_split(
    inputs, dbzh, time, test_size=(1 - train_size), random_state=seed, shuffle=shuffle
)

# Second split: remainder -> validation / test, in proportion val_size : test_size
val_ratio = val_size / (val_size + test_size)  # 0.5 with the default fractions
inputs_val, inputs_test, dbzh_val, dbzh_test, time_val, time_test = train_test_split(
    inputs_remain, dbzh_remain, time_remain, test_size=(1 - val_ratio), random_state=seed, shuffle=shuffle
)

# Reflectivity threshold that marks a training timestep as a rare event (augmentation below)
high_dbzh_threshold = 13.5

print('\n@ Train, Val, Test shapes')
for label, arr in (
    ('inputs_train', inputs_train), ('inputs_val', inputs_val), ('inputs_test', inputs_test),
    ('dbzh_train', dbzh_train),     ('dbzh_val', dbzh_val),     ('dbzh_test', dbzh_test),
    ('time_train', time_train),     ('time_val', time_val),     ('time_test', time_test),
):
    print(f'{label.ljust(12)} shape  --> {arr.shape}', flush=True)


######## AUGMENTATION #############

# -----------------------------
# STEP 2. Identify rare events
# -----------------------------
# A training timestep is a "rare event" when at least one of its pixels reaches
# high_dbzh_threshold.
_dbzh_train_flat = dbzh_train.reshape(dbzh_train.shape[0], -1)  # (N, n_pixels)
rare_timesteps_idx = np.any(_dbzh_train_flat >= high_dbzh_threshold, axis=1)
del _dbzh_train_flat  # may be a view of dbzh_train; do not keep it around

inputs_rare, dbzh_rare = inputs_train[rare_timesteps_idx], dbzh_train[rare_timesteps_idx]

print(f"Rare event timesteps found: {inputs_rare.shape[0]} / {inputs_train.shape[0]}")

# -----------------------------
# STEP 3. Augment rare events
# -----------------------------
def rotate_batch(data, k):
    """Rotate every sample in ``data`` by ``k`` quarter turns.

    The rotation plane is axes 1 and 2 (the two axes after the leading batch
    axis); any trailing axes, e.g. channels, are carried along unchanged.
    ``k`` follows ``np.rot90`` semantics: rotation goes from axis 1 towards
    axis 2, and negative ``k`` rotates the other way. Returns a view of ``data``.
    """
    plane = (1, 2)
    return np.rot90(data, k, plane)

# Pieces to concatenate: the full training set followed by rotated copies of the rare events.
aug_inputs = [inputs_train]
aug_dbzh   = [dbzh_train]

# Quarter turns keep the (H, W) shape only on a square grid. On a rectangular
# grid, restrict to 0° and 180° so the concatenation below still works.
if inputs_train.shape[1] == inputs_train.shape[2]:
    quarter_turns = (0, 1, 2, 3)  # 0°, 90°, 180°, 270°
else:
    quarter_turns = (0, 2)         # 0°, 180°

# k = 0 appends an unrotated copy of every rare event, i.e. plain oversampling.
# Flips (np.flip over axis 1 or 2) could be appended here as extra augmentation.
for k in quarter_turns:
    rotated_inputs = rotate_batch(inputs_rare, k)
    rotated_dbzh   = rotate_batch(dbzh_rare, k)
    aug_inputs.append(rotated_inputs)
    aug_dbzh.append(rotated_dbzh)


# Concatenate all augmented data
inputs_train_aug = np.concatenate(aug_inputs, axis=0)
dbzh_train_aug   = np.concatenate(aug_dbzh, axis=0)
# The pieces (and the rare-event subsets they view) would otherwise keep the
# pre-augmentation arrays alive next to the concatenated copies.
del aug_inputs, aug_dbzh, rotated_inputs, rotated_dbzh, inputs_rare, dbzh_rare

# Shuffle the augmented training set. The function is imported under an alias so
# the config's `shuffle` flag (used by the train/val/test split) is not overwritten.
from sklearn.utils import shuffle as sk_shuffle
inputs_train_aug, dbzh_train_aug = sk_shuffle(inputs_train_aug, dbzh_train_aug, random_state=seed)

print(f"Final training shape after oversampling + augmentation: {inputs_train_aug.shape}, {dbzh_train_aug.shape}")


inputs_train = inputs_train_aug
dbzh_train  = dbzh_train_aug


del dbzh1, dbzh
gc.collect()


print(f'\nNormalize input using min-max and dbzh using {output_normalization_type} normalization ...', flush=True)
print('######################################################################################', flush=True)

# Inputs: always min-max; statistics are stored per variable as '<var>_min' / '<var>_max'.
input_normalize_fn = min_max_normalization
input_min_dict, input_max_dict = {}, {}
input_min_suffix, input_max_suffix = 'min', 'max'

# Output (DBZH): the normalization type is chosen in the config.
if output_normalization_type not in ('log', 'minmax'):
    raise ValueError(f"Unsupported output normalization type: {output_normalization_type}")

if output_normalization_type == 'log':
    output_normalize_fn = log_normalization
    output_min_suffix, output_max_suffix = 'log_min', 'log_max'
else:
    output_normalize_fn = min_max_normalization
    output_min_suffix, output_max_suffix = 'min', 'max'

# Per-variable slices of the stacked (N, H, W, C) arrays. Channel c belongs to the
# c-th key of `variables`, the order in which they were stacked.
variables_train = {f'{name}_train': inputs_train[:, :, :, c] for c, name in enumerate(variables)}
variables_val   = {f'{name}_val': inputs_val[:, :, :, c] for c, name in enumerate(variables)}
variables_test  = {f'{name}_test': inputs_test[:, :, :, c] for c, name in enumerate(variables)}

variables_train_norm, variables_val_norm, variables_test_norm = {}, {}, {}

# Normalize inputs (always min-max). Every call receives the training slice as
# the reference; only the 'training' call also returns the (min, max) pair,
# which is recorded for later denormalization.
for name in variables:
    train_field = variables_train[f'{name}_train']
    (variables_train_norm[f'{name}_train_norm'],
     input_min_dict[f'{name}_{input_min_suffix}'],
     input_max_dict[f'{name}_{input_max_suffix}']) = input_normalize_fn('training', train_field, train_field)

    variables_val_norm[f'{name}_val_norm'] = input_normalize_fn('validation', train_field, variables_val[f'{name}_val'])
    variables_test_norm[f'{name}_test_norm'] = input_normalize_fn('test', train_field, variables_test[f'{name}_test'])

# Normalize the output (DBZH). Every call gets dbzh_train as the reference array
# (presumably the source of the statistics); the 'training' call also returns
# them as dbzh_min / dbzh_max.
dbzh_train_norm, dbzh_min, dbzh_max = output_normalize_fn('training', dbzh_train, dbzh_train)
dbzh_val_norm  = output_normalize_fn('validation', dbzh_train, dbzh_val)
dbzh_test_norm = output_normalize_fn('test', dbzh_train, dbzh_test)

# Optionally persist the normalization statistics next to the data
if save_datasets_norm:
    print(f'\nSaving input min-max and dbzh {output_normalization_type} dictionaries...', flush=True)
    for payload, stem in (
        (input_min_dict, 'variables_min'),
        (input_max_dict, 'variables_max'),
        (dbzh_min, f'dbzh_{output_min_suffix}'),
        (dbzh_max, f'dbzh_{output_max_suffix}'),
    ):
        torch.save(payload, f'{data_dir}/{stem}.pth')

def _print_range(label, lo, hi, flush=False):
    """Print one ``label --> (lo, hi)`` report line, padded so the columns align."""
    print(f'{label.ljust(20)} --> ({str(lo).ljust(13)}, {hi})', flush=flush)


# Raw (not normalized) min/max recorded for each input and for DBZH
print('\n@ Input (min-max) not normalized values')
for name in variables:
    _print_range(name, input_min_dict[f"{name}_min"], input_max_dict[f"{name}_max"])
_print_range('dbzh', dbzh_min, dbzh_max, flush=True)

# Range of every normalized array, split by split (NaNs ignored)
for heading, norm_fields, dbzh_label, dbzh_norm in (
    ('\n@ Normalized Train', variables_train_norm, 'dbzh_train_norm', dbzh_train_norm),
    ('\n@ Normalized Validation', variables_val_norm, 'dbzh_val_norm', dbzh_val_norm),
    ('\n@ Normalized Test', variables_test_norm, 'dbzh_test_norm', dbzh_test_norm),
):
    print(heading)
    for name, field in norm_fields.items():
        _print_range(name, np.nanmin(field), np.nanmax(field))
    _print_range(dbzh_label, np.nanmin(dbzh_norm), np.nanmax(dbzh_norm), flush=True)

# Re-stack the normalized variables along a trailing channel axis (same order as before)
inputs_train_norm = np.stack(list(variables_train_norm.values()), axis=-1)
inputs_val_norm   = np.stack(list(variables_val_norm.values()), axis=-1)
inputs_test_norm  = np.stack(list(variables_test_norm.values()), axis=-1)

print('\n@ Shape of stacked variables normalized array', flush=True)
print(f'Training                --> {inputs_train_norm.shape}', flush=True)
print(f'Valildation             --> {inputs_val_norm.shape}', flush=True)
print(f'Test                    --> {inputs_test_norm.shape}', flush=True)
print(f'dbzh_train shape        --> {dbzh_train_norm.shape}', flush=True)
print(f'dbzh_val shape          --> {dbzh_val_norm.shape}', flush=True)
print(f'dbzh_test shape         --> {dbzh_test_norm.shape}', flush=True)


print('\nCreate train, validation and test datasets...', flush=True)
print('######################################################################################', flush=True)
dataset_train = MultichannelDataset(inputs_train_norm, dbzh_train_norm)
dataset_test  = MultichannelDataset(inputs_test_norm, dbzh_test_norm)
dataset_val   = MultichannelDataset(inputs_val_norm, dbzh_val_norm)

if save_datasets:
    print('\nSaving train, validation and test datasets...', flush=True)
    for split_dataset, split_name in ((dataset_train, 'train'), (dataset_test, 'test'), (dataset_val, 'val')):
        torch.save(split_dataset, f'{data_dir}/dataset_{split_name}.pth', pickle_protocol=4)


print('\nCreate train, validation and test dataloaders...', flush=True)
print('######################################################################################', flush=True)
# All loaders use shuffle=False. The training set was shuffled once after
# augmentation, so batches keep the same order every epoch.
# NOTE(review): set shuffle=True on the training loader if per-epoch reshuffling
# is wanted; the seeded generator `g` is already passed for reproducibility.
_loader_kwargs = dict(batch_size=batch_size, shuffle=False, worker_init_fn=seed_worker, generator=g)
train_loader = DataLoader(dataset_train, **_loader_kwargs)
test_loader  = DataLoader(dataset_test, **_loader_kwargs)
val_loader   = DataLoader(dataset_val, **_loader_kwargs)


# The loaders keep their datasets alive; drop the module-level references.
del inputs_train_norm, inputs_val_norm, inputs_test_norm, dataset_train, dataset_test, dataset_val
gc.collect()


print('\nTraining...', flush=True)
print('######################################################################################', flush=True)
train_model(device, model, train_loader, val_loader, criterion, optimizer, fig_dir, fig_pred_dir, LOSSweights, num_epochs)


# SAVE MODEL (weights only, via state_dict)
model_path = f"{model_dir}/{model_name}"
torch.save(model.state_dict(), model_path)
print("\nModel saved successfully!")


print('\nEvaluation ...', flush=True)
print('######################################################################################', flush=True)

# Set model to evaluation mode
model.eval()

# Per-batch CPU copies of predictions, targets and (normalized) inputs, collected
# for denormalization, metrics and plotting below.
all_predictions = []
all_truths = []
all_inputs = []

# Disable gradient computation
with torch.no_grad():
    # The loop variable deliberately reuses the name `inputs`: rebinding it also
    # releases the large stacked input array built before the train/val/test split.
    for inputs, reflectivity in test_loader:
        # Only the model input needs to be on the device; the target is kept on
        # the CPU as delivered by the loader.
        outputs = model(inputs.to(device))

        all_predictions.append(outputs.cpu())
        all_truths.append(reflectivity.cpu())
        all_inputs.append(inputs.cpu())
        


print('\nCompute prediction in the RADAR space ...', flush=True)
print('######################################################################################', flush=True)

print(f'len all_inputs list after validation     -- > {len(all_inputs)}')
print(f'len all_prediction list after validation -- > {len(all_predictions)}')
print(f'len all_truth list after validation      -- > {len(all_truths)}')

print('\n@ Denormalization of all_predictions, all_truths, and all_inputs')

# DBZH denormalization mirrors the normalization chosen in the config.
if output_normalization_type == 'log':
    output_denormalize_fn = log_denormalization
elif output_normalization_type == 'minmax':
    output_denormalize_fn = min_max_denormalization
else:
    raise ValueError(f"Unsupported output normalization type: {output_normalization_type}")
denorm_min = dbzh_min
denorm_max = dbzh_max

# Inputs are always min-max normalized. Use the statistics computed on the training
# split in this run (input_min_dict / input_max_dict) rather than re-reading them
# from data_dir. Those .pth files are written only when save_datasets_norm is True,
# so re-reading could fail, or silently load a previous run's statistics.
input_denormalize_fn = min_max_denormalization

# Denormalize predictions and truths (output)
all_predictions = [output_denormalize_fn(pred, denorm_min, denorm_max) for pred in all_predictions]
all_truths      = [output_denormalize_fn(truth, denorm_min, denorm_max) for truth in all_truths]

print(f'len all_prediction       list after denormalization     -- > {len(all_predictions)}')
print(f'len all_truth            list after denormalization     -- > {len(all_truths)}')

# One list of denormalized batches per input variable
all_inputs_dict = {f'all_inputs_{var}': [] for var in variables.keys()}

# Channel idx of every input batch is the idx-th key of `variables`.
# NOTE(review): assumes MultichannelDataset yields channels-first (B, C, H, W)
# batches, as the `inp[:, idx, :, :]` indexing implies; confirm.
for inp in all_inputs:
    for idx, key_var in enumerate(variables.keys()):
        min_val = input_min_dict[f'{key_var}_min']
        max_val = input_max_dict[f'{key_var}_max']
        denorm_var = input_denormalize_fn(inp[:, idx, :, :], min_val, max_val)
        all_inputs_dict[f'all_inputs_{key_var}'].append(denorm_var.cpu())

# Print stats
for key in all_inputs_dict.keys():
    print(f'len {str(key).ljust(20)} list after denormalization     -- > {len(all_inputs_dict[key])}')


print('\n@ Concatenate batches all_predictions, all_truths, and all_inputs')
all_predictions = torch.cat(all_predictions)
all_truths      = torch.cat(all_truths)
print(f'all_prediction       shape -- > {all_predictions.shape}')
print(f'all_truth            shape -- > {all_truths.shape}')
for key in all_inputs_dict.keys():
    all_inputs_dict[key] = torch.cat(all_inputs_dict[key])
    print(f'{str(key).ljust(20)} shape -- > {all_inputs_dict[key].shape}' )


print('\n@ Remove channel dimension of all_predictions, all_truths, and all_inputs for plotting')

# Remove the singleton channel dimension (axis 1) for plotting
all_predictions = all_predictions.squeeze(1)
all_truths      = all_truths.squeeze(1)

# Saving all_prediction & all_truth
np.save(f'{data_dir}/all_predictions.npy', all_predictions)
np.save(f'{data_dir}/all_truths.npy', all_truths)

print(f'all_prediction       shape -- > {all_predictions.shape}')
print(f'all_truth            shape -- > {all_truths.shape}')
for key in all_inputs_dict.keys():
    all_inputs_dict[key] = all_inputs_dict[key].squeeze(1)
    print(f'{str(key).ljust(20)} shape -- > {all_inputs_dict[key].shape}' )


print('\n@ Compute metrics between prediction and target for each timestep: RME and SSIM')

BIAS_values = []
RMSE_values = []
SSIM_values = []

# Convert to numpy once instead of re-converting the whole tensors several times per timestep.
pred_np  = all_predictions.cpu().numpy()
truth_np = all_truths.cpu().numpy()

# Iterate over every test sample. The loop is sized by the test arrays
# themselves: the validation split can be shorter than the test split, so its
# length must not be used here.
for t in range(len(truth_np)):
    pred_t, truth_t = pred_np[t], truth_np[t]

    # Bias only over pixels with non-zero observed reflectivity; NaN when there are none.
    mask = truth_t != 0
    bias = np.mean(pred_t[mask] - truth_t[mask]) if mask.any() else np.nan
    rmse = np.sqrt(np.mean((pred_t - truth_t) ** 2))
    ssim_value = ssim(pred_t, truth_t, data_range=truth_t.max() - truth_t.min())

    BIAS_values.append(bias)
    RMSE_values.append(rmse)
    SSIM_values.append(ssim_value)

BIAS_values = np.array(BIAS_values)
RMSE_values = np.array(RMSE_values)
SSIM_values = np.array(SSIM_values)

print(BIAS_values.shape, RMSE_values.shape)

np.save(f'{data_dir}/BIAS.npy', BIAS_values)
np.save(f'{data_dir}/RMSE.npy', RMSE_values)
np.save(f'{data_dir}/SSIM.npy', SSIM_values)
np.save(f'{data_dir}/TIME_TEST.npy', time_test)

# Averages skip the NaN entries (e.g. samples with no observed echo for the bias)
avg_BIAS = np.nanmean(BIAS_values)
avg_RMSE = np.nanmean(RMSE_values)

print(f'The best in RMSE and SSIM are (idx) : ({np.argmin(RMSE_values)},{np.argmax(SSIM_values)})')
print(f'The worst in RMSE and SSIM are (idx): ({np.argmax(RMSE_values)},{np.argmin(SSIM_values)})')
print(f'Average BIAS is: {avg_BIAS}')
print(f'Average RMSE is: {avg_RMSE}')


print('\n@ Select samples to plot')
n_samples_to_plot = len(all_predictions)
sample_predictions = all_predictions[:n_samples_to_plot]
sample_truths = all_truths[:n_samples_to_plot]

# Copy of the predictions in which every value below 13.5 is set to 0
below_threshold = sample_predictions < 13.5
sample_predictions_filtered = sample_predictions.clone()
sample_predictions_filtered[below_threshold] = 0

print(f'sample_predictions   shape -- > {sample_predictions.shape}')
print(f'sample_truths        shape -- > {sample_truths.shape}')

# Denormalized inputs for the same samples, keyed 'sample_inputs_<variable>'
samples_dict = {
    f'sample_inputs_{name}': all_inputs_dict[source_key][:n_samples_to_plot]
    for name, source_key in zip(variables.keys(), all_inputs_dict.keys())
}

for key in samples_dict.keys():
    print(f'{str(key).ljust(20)} shape -- > {samples_dict[key].shape}')
 
print('\n@ Plotting ...', flush=True)
print('######################################################################################', flush=True)
print(f'Images saved in:\n{fig_pred_dir}', flush=True)


def _plot_field_panel(position, field, title, cmap='gist_ncar_r', vmin=None, vmax=None):
    """Draw a 2-D ``field`` with a bottom colorbar in subplot ``position`` of the 8x4 grid."""
    plt.subplot(8, 4, position)
    plt.imshow(field, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.colorbar(location='bottom')
    plt.gca().invert_yaxis()


def _plot_scatter_panel(position, predictions, truths, label, title):
    """Scatter predicted vs. true reflectivity in subplot ``position`` of the 8x4 grid.

    Both axes share limits [lo, hi] spanning exactly the plotted data, with ticks
    every 10 units and one dashed y = x reference segment. Plotting the 2-D truth
    field against itself would draw one line per image column.
    """
    plt.subplot(8, 4, position)
    plt.scatter(predictions, truths, color='r', label=label)

    lo = float(min(predictions.min(), truths.min()))
    hi = float(max(predictions.max(), truths.max()))
    if hi <= lo:
        # Constant fields: widen the range so the axis limits are not identical
        hi = lo + 1.0
    plt.plot([lo, hi], [lo, hi], color='k', linestyle='--', label="y = x")

    ticks = np.arange(lo, hi, 10)  # tick step of 10 on both axes
    plt.xticks(ticks)
    plt.yticks(ticks)
    plt.xlim(lo, hi)
    plt.ylim(lo, hi)

    plt.xlabel("Predicted Reflectivity")
    plt.ylabel("True Reflectivity")
    plt.title(title)
    plt.gca().set_aspect('equal', adjustable='box')


def _plot_sample_figure(i, truth_title, show_difference, out_path, **suptitle_kwargs):
    """Save the diagnostic figure for test sample ``i`` to ``out_path``.

    Panels 1..n hold the denormalized inputs; 25-28 hold truth, prediction, the
    optional prediction-minus-truth map and the >= 13.5 filtered prediction; 29-30
    hold the scatter plots. Truth, prediction and filtered panels share a colour
    range from the truth. NOTE: more than 24 input variables would overlap panel 25.
    """
    fig = plt.figure(figsize=(12, 23))
    try:
        truth = np.asarray(sample_truths[i])
        prediction = np.asarray(sample_predictions[i])
        filtered = np.asarray(sample_predictions_filtered[i])
        vmin, vmax = float(truth.min()), float(truth.max())

        for idx, key in enumerate(variables.keys()):
            _plot_field_panel(idx + 1, np.asarray(samples_dict[f'sample_inputs_{key}'][i, :, :]), f'Input ({key})')

        _plot_field_panel(25, truth, truth_title, vmin=vmin, vmax=vmax)
        _plot_field_panel(26, prediction, 'Prediction', vmin=vmin, vmax=vmax)
        if show_difference:
            _plot_field_panel(27, prediction - truth, 'Prediction - Truth', cmap='seismic', vmin=-20, vmax=20)
            filtered_position = 28
        else:
            filtered_position = 27
        _plot_field_panel(filtered_position, filtered, 'Prediction >= 13.5', vmin=vmin, vmax=vmax)

        _plot_scatter_panel(29, prediction, truth, "Prediction vs Truth", "Prediction vs Ground Truth")
        _plot_scatter_panel(30, filtered, truth, "Prediction >= 13.5 vs Truth", "Prediction >= 13.5 vs Ground Truth")

        plt.suptitle(f'{time_test[i]}', **suptitle_kwargs)
        plt.tight_layout()
        plt.savefig(out_path)
    finally:
        plt.close(fig)


# Best and worst samples by RMSE and SSIM (an index shared by two criteria is drawn once)
extreme_indices = dict.fromkeys(
    [np.argmin(RMSE_values), np.argmax(RMSE_values), np.argmax(SSIM_values), np.argmin(SSIM_values)]
)
for i in extreme_indices:
    _plot_sample_figure(i, 'Ground Truth (Reflectivity)', show_difference=False,
                        out_path=f'{fig_pred_dir}/0_best_worst_{num_epochs}_{i}')

###############################################################################
# Plot each sample (at most the first 100)
num_sample_plot = min(len(sample_predictions), 100)
for i in range(num_sample_plot):
    _plot_sample_figure(i, 'Ground Truth', show_difference=True,
                        out_path=f'{fig_pred_dir}/Sample_{str(i + 1).zfill(4)}_{num_epochs}_{i}',
                        x=0.5, y=0.99)

print('\nDONE!!!')

