import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
import os
import torch
import random
import numpy as np

from ResUNet import ResUNet
from Utils import create_dir, count_parameters_by_layer
from WeightedMSELoss import WeightedMSELoss
from EdgeAwareWeightedLoss import EdgeAwareWeightedLoss


seed = 42

# Seed every random number generator used by the run so results are repeatable.
random.seed(seed)                 # Python's built-in random module
np.random.seed(seed)              # NumPy global RNG
torch.manual_seed(seed)           # PyTorch default generator
torch.cuda.manual_seed(seed)      # current CUDA device (if using CUDA)
torch.cuda.manual_seed_all(seed)  # every CUDA device (multi-GPU)

# cuDNN reproducibility switches:
#   benchmark = False     -> no performance-based auto-selection of algorithms
#                            (that selection may introduce run-to-run variation).
#   deterministic = True  -> restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Dedicated generator with its own seed, set up for deterministic shuffling.
# Seeding `g` affects only this generator object, not PyTorch's global state.
g = torch.Generator().manual_seed(seed)



# Define worker initialization function for DataLoader
def seed_worker(worker_id):
    """Re-seed the Python, NumPy and PyTorch RNGs inside a worker process.

    The seed is derived from the worker's current PyTorch seed, reduced
    modulo 2**32, and then applied to every RNG the worker may use.

    Args:
        worker_id: Index of the worker; accepted for the worker-init
            callback signature but not used here.
    """
    print(f'Print torch ini seed  in seed worker {torch.initial_seed()}')

    # Fold the PyTorch seed into 32 bits so all three libraries accept it.
    worker_seed = torch.initial_seed() % (2 ** 32)

    # Same order as before: NumPy, Python's random, then PyTorch (CPU).
    for reseed in (np.random.seed, random.seed, torch.manual_seed):
        reseed(worker_seed)

    # Nothing more to do on machines without CUDA.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(worker_seed)      # current GPU
    torch.cuda.manual_seed_all(worker_seed)  # all GPUs (multi-GPU setup)

# Weight initialisation applied to Conv2d / Linear layers
def init_weights(m):
    """Initialise one module in place: N(0, 0.02) weights and zero bias.

    Only ``torch.nn.Conv2d`` and ``torch.nn.Linear`` modules are touched;
    every other module type is left unchanged.

    Args:
        m: The module to (possibly) re-initialise.
    """
    if not isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        return

    torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)

    # Layers created with bias=False have no bias tensor to reset.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0)

# Prefer the GPU when CUDA is usable; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


# Export the seed as PYTHONHASHSEED.
# NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
# assignment presumably only affects processes launched afterwards that
# inherit this environment. Export the variable before starting the script if
# hash randomisation of this process itself must be fixed.
os.environ["PYTHONHASHSEED"] = f"{seed}"

# Control multi-threading behavior (limit CPU math libraries to 15 threads).
#
# torch and numpy are already imported when these lines run, and OpenMP/MKL
# typically read their thread-count variables when the runtime initialises,
# so the assignments alone may be ignored by this process. They are kept for
# child processes (which inherit the environment) and for the prints below.
# NOTE(review): TORCH_NUM_THREADS does not appear among PyTorch's documented
# environment variables - confirm whether anything actually reads it.
os.environ["OMP_NUM_THREADS"]   = "15"
os.environ["MKL_NUM_THREADS"]   = "15"
os.environ["TORCH_NUM_THREADS"] = "15"

# Apply the limit explicitly to the already-loaded PyTorch runtime so the
# intra-op CPU thread count really is the configured value.
torch.set_num_threads(int(os.environ["TORCH_NUM_THREADS"]))

print(os.environ["OMP_NUM_THREADS"],  flush=True)
print(os.environ["MKL_NUM_THREADS"],  flush=True)
print(os.environ["TORCH_NUM_THREADS"],flush=True)

# DATA
shuffle = True
save_datasets_norm = True
save_datasets      = False

# Use this number to add a quantity where dbzh != 0. Use 0 to leave as dataset.
#num_to_add = 100

# Switches for restricting the data to desired_years / desired_hours below.
hours_selection = True
years_selection = True

desired_years = [2019, 2020, 2021, 2022]


desired_hours = [0, 6, 12, 18]
# M1 (Model 1)
#desired_hours = [1, 7, 13, 19]

# M2 (Model 2)
#desired_hours = [2, 8, 14, 20]

# M3 (Model 3)
#desired_hours = [3, 9, 15, 21]

# M4 (Model 4)
#desired_hours = [4, 10, 16, 22]

# M5 (Model 5)
#desired_hours = [5, 11, 17, 23]

var_names = ['t', 'v', 'u', 'r', 'r2m', 't2m', 'msl']


# Number of levels used per multi-level variable (0 disables the variable).
qLevels = 0
tLevels = 4
vLevels = 4
uLevels = 4
rLevels = 4

output_normalization_type = 'minmax'  # 'log' or 'minmax'


# Extra input channels per variable: levels - 1 when the variable is enabled,
# otherwise 0. Presumably the first level is already counted through its
# entry in var_names - verify against the data loader.
tcount = tLevels - 1 if tLevels != 0 else 0
qcount = qLevels - 1 if qLevels != 0 else 0
vcount = vLevels - 1 if vLevels != 0 else 0
ucount = uLevels - 1 if uLevels != 0 else 0
rcount = rLevels - 1 if rLevels != 0 else 0

# Total number of input channels handed to the network.
in_channels = len(var_names) + tcount + qcount + vcount + ucount + rcount

# INPUT DATASET PATHS
aladin = '../../ALADIN_2019_2023_radar_time_aligned.nc'
radar  = '../../RADAR_2019_2023_INTERPOLATED_160km.nc'

# DIRECTORIES
out_dir      = 'TEST'
fig_dir      = out_dir + '/FIGURES'
fig_pred_dir = fig_dir + '/PREDICTIONS'
data_dir     = out_dir + '/DATA'
model_dir    = out_dir + '/MODELS'
model_name   = 'UNet_min_max_norm_ConfigV1.pth'

# Create the output folders. fig_dir itself is never passed to create_dir;
# presumably create_dir also creates missing parents - TODO confirm.
for _output_dir in (fig_pred_dir, data_dir, model_dir):
    create_dir(_output_dir)

# HYPERPARAMETERS
# Build the network with `in_channels` input channels.
# NOTE(review): the model is not moved to `device` in this section - confirm
# that this happens before training starts.
model                   = ResUNet(in_channels)

# Re-initialise every submodule with init_weights (Module.apply visits them all).
model.apply(init_weights)

batch_size              = 16
lr                      = 1e-3
num_epochs              = 100
optimizer               = optim.Adam(model.parameters(), lr=lr)

# Training loss. The meaning of each weight is defined by
# EdgeAwareWeightedLoss, which is not visible in this file.
criterion = EdgeAwareWeightedLoss(
    λ_l1=1.0, 
    λ_grad=0.7, 
    λ_ssim=0.3,
    λ_wmse=0.1,          
    wmse_alpha=4.5,
    wmse_base=1.0,
    wmse_max_weight=20.0,
    use_intensity_weight=False,  
    weight_normalize=True
)


# Early-stopping settings (presumably epochs without improvement and the
# minimum improvement that counts - verify against the training loop).
early_stopping_patience = 10
min_delta               = 1e-4
radius                  = 20 # Radius for weight function
# FILTER PRECIPITATION TIMESTEP
dbz_threshold            = 13.5 
precipitation_percentage = 0 

# COUNTING MODEL PARAMETERS
layer_params, total_params = count_parameters_by_layer(model)
print(f'Number of model params --> {total_params}')

# Split the per-layer mapping into parallel lists for plotting.
layer_names = list(layer_params.keys())
param_counts = list(layer_params.values())

# Horizontal bar chart with one bar per layer.
params_fig, params_ax = plt.subplots(figsize=(15, 10))
params_ax.barh(layer_names, param_counts, color='skyblue')
params_ax.set_xlabel("Number of Parameters")
params_ax.set_ylabel("Layer Name")
params_ax.set_title("Number of Parameters per Layer in Model")
params_ax.invert_yaxis()  # Optional: first layer at the top for easier reading
params_fig.tight_layout()
params_fig.savefig(f'{fig_dir}/Parameter_per_layer_in_model.png')
plt.close(params_fig)

print(f'Look at {fig_dir}/Parameter_per_layer_in_model.png to see parameters structure.', flush=True)
