import argparse
import torch
import os
import numpy as np
import xarray as xr
import yaml
from tqdm import tqdm
import pandas as pd
from datetime import datetime
import time
import cv2
import rasterio
from torch.utils.data import Dataset
from scipy.ndimage import zoom
import sys
import onnxruntime as ort

# ANSI terminal escape sequences used to colour console log prefixes.
RESET_SEQ = "\033[0m"
RED_SEQ = "\033[31m"
GREEN_SEQ = "\033[32m"

# Canonical order of every variable/level. Positions in this list are used to
# index the per-channel normalisation statistics (see zarrDataset.var_index).
VAR_NAMES = [
    f"{var}{level}" for var in ("z", "t", "u", "v") for level in (250, 500, 700, 850)
] + ["t2m", "u10", "v10", "msl", "tp"]

def get_DEM(file_path="./utils/GENCO_world_resampled_0.25deg.tif", size=(1440, 720)):
    """Load the global elevation raster, re-centre its longitudes and standardise it.

    Args:
        file_path: Path to a GeoTIFF whose first band is read as the elevation field.
        size: Target ``(width, height)`` passed to ``cv2.resize`` (OpenCV takes
            width first).

    Returns:
        2-D array of shape ``(height, width)``, with its columns rolled left by
        half the width and z-scored (zero mean, unit standard deviation).
    """
    # Only the read needs the file handle; release it before the heavy work.
    with rasterio.open(file_path) as dataset:
        band = dataset.read(1)
    band = cv2.resize(band, size)

    # Move the western half of the grid to the east so longitudes run 0~360
    # instead of -180~180 (assumes the raster's first column is at -180 deg).
    # Split at half of the *resized* width: the previous code used the raster's
    # native width, which was only correct when the file already had exactly
    # size[0] columns.
    half = band.shape[1] // 2
    band = np.roll(band, -half, axis=1)
    return (band - band.mean()) / band.std()

def time_to_features(timestamp, height, width):
    """Encode the calendar month and day of *timestamp* as constant feature maps.

    Args:
        timestamp: Either a ``datetime`` or a string formatted as
            ``"%Y-%m-%dT%H:%M:%S"`` (e.g. ``"2015-01-01T00:00:00"``).
        height: Number of rows of each output map.
        width: Number of columns of each output map.

    Returns:
        float32 array of shape ``(4, height, width)`` holding, in order,
        ``sin``/``cos`` of ``2*pi*month/12`` and ``sin``/``cos`` of
        ``2*pi*day/31``, identical at every pixel. The result is a read-only
        broadcast view; copy it before writing into it.
    """
    if isinstance(timestamp, datetime):
        dt = timestamp
    else:
        dt = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S")

    # Sine/cosine (cyclical) encoding. The day uses a fixed 31-day period
    # regardless of the actual month length.
    month_angle = 2 * np.pi * dt.month / 12.0
    day_angle = 2 * np.pi * dt.day / 31.0

    features = np.array(
        [np.sin(month_angle), np.cos(month_angle), np.sin(day_angle), np.cos(day_angle)],
        dtype=np.float32,
    )
    return np.broadcast_to(features[:, None, None], (4, height, width))
 
def data_process_sub(data2d, grid_spacing=17.6, min_wavelength=500.0):
    """Low-pass filter a 2-D field by zeroing its short-wavelength Fourier modes.

    Args:
        data2d: 2-D array ``(ny, nx)``.
        grid_spacing: Sample spacing along both axes, in the same units as
            ``min_wavelength`` (default 17.6; presumably km per cell -- TODO
            confirm).
        min_wavelength: Modes whose wavelength ``1/|f|`` is shorter than this
            are removed; the mean (zero-frequency) term is always kept.

    Returns:
        float64 array of the same shape: the real part of the inverse FFT of
        the masked spectrum.
    """
    ny, nx = data2d.shape
    spectrum = np.fft.fft2(data2d)

    # Frequencies in native (unshifted) FFT order, so the mask lines up with
    # `spectrum` directly; no fftshift/ifftshift round-trip is needed.
    fx = np.fft.fftfreq(nx, d=grid_spacing)
    fy = np.fft.fftfreq(ny, d=grid_spacing)
    freq_radius = np.sqrt(fx[None, :] ** 2 + fy[:, None] ** 2)

    # The epsilon keeps the DC term (radius 0) finite and gives it a huge
    # wavelength, so it always passes the threshold.
    wavelength = 1.0 / (freq_radius + 1e-10)
    keep = wavelength >= min_wavelength

    return np.fft.ifft2(spectrum * keep).real

def data_process(data):
    """Apply ``data_process_sub`` to every 2-D slice over the last two axes.

    Args:
        data: Array with at least two dimensions. Any leading axes (e.g. a
            channel axis for 3-D input) are iterated over.

    Returns:
        For 2-D input, ``data_process_sub(data)`` as-is. For higher rank, an
        array with the same shape and dtype as ``data`` where each trailing 2-D
        slice has been filtered. The filtered values are cast back to the input
        dtype.

    Raises:
        ValueError: If ``data`` has fewer than two dimensions (previously such
            input silently returned ``None``).
    """
    data = np.asarray(data)
    if data.ndim < 2:
        raise ValueError(
            f"data_process expects an array with at least 2 dimensions, got {data.ndim}"
        )
    if data.ndim == 2:
        return data_process_sub(data)

    filtered_data = np.zeros_like(data)
    for idx in np.ndindex(data.shape[:-2]):
        filtered_data[idx] = data_process_sub(data[idx])
    return filtered_data

class zarrDataset(Dataset):
    """Inference dataset serving up-sampled, low-pass-filtered fields plus static conditioning.

    Each item is ``[input, static_data, sample_id]`` where:

    * ``input`` is the selected variables at one time step, nearest-neighbour
      resampled to ``OUT_HEIGHT x OUT_WIDTH`` and passed through
      ``data_process``;
    * ``static_data`` stacks the normalised DEM (1 channel) with the 4
      calendar-encoding channels from ``time_to_features``;
    * ``sample_id`` is the time stamp as a second-resolution ISO string.
    """

    # Spatial size of the grid produced for the model (rows, columns).
    OUT_HEIGHT = 720
    OUT_WIDTH = 1440

    def __init__(self, flag='infer', var_names=None,
                 zarr_path='./ec-earth-ssp126.20150101_20150131.c21.p4.h24',
                 mean_path='./utils/z_mean_by_level.npy',
                 std_path='./utils/z_std_by_level.npy'):
        """Open the source zarr store and load normalisation statistics and the DEM.

        Args:
            flag: Label used only in the load log message.
            var_names: Variables to select (must be members of ``VAR_NAMES``).
                ``None`` selects all of ``VAR_NAMES``; previously ``None``
                raised a ``TypeError``.
            zarr_path: Path of the input zarr store.
            mean_path: ``.npy`` file of per-variable means, indexed like ``VAR_NAMES``.
            std_path: ``.npy`` file of per-variable stds, indexed like ``VAR_NAMES``.
        """
        start_time = time.time()
        self.flag = flag
        var_names = list(VAR_NAMES) if var_names is None else list(var_names)

        self.mean = np.load(mean_path)
        self.std = np.load(std_path)
        zarr_input = xr.open_zarr(zarr_path)['__xarray_dataarray_variable__']
        self.global_dem = np.expand_dims(get_DEM().astype(np.float32), axis=0)

        # Positions of the requested variables within VAR_NAMES, used by callers
        # to pick the matching entries of self.mean / self.std.
        self.var_index = np.array([VAR_NAMES.index(var) for var in var_names])
        # The variable axis of the store is named "plev".
        self.input_dataset = zarr_input.astype(np.float32).sel(plev=var_names)

        self.time_list = self.input_dataset['time'].values
        self.dataset_len = len(self.time_list)

        execution_time = time.time() - start_time
        print(GREEN_SEQ+"【INFO】"+RESET_SEQ+f"{flag} dataset has been loaded. (time cost: {execution_time}s)")

    def __getitem__(self, index):
        """Return ``[input, static_data, sample_id]`` for time step *index*."""
        sample_id = str(np.datetime64(self.time_list[index], 's'))
        time_embed = time_to_features(sample_id, self.OUT_HEIGHT, self.OUT_WIDTH)
        static_data = np.concatenate([self.global_dem, time_embed], axis=0)

        fields = self.input_dataset.isel(time=index).values
        # Nearest-neighbour up-sampling onto the model grid. The factors are
        # derived from the actual source grid instead of assuming 256x512.
        _, src_h, src_w = fields.shape
        fields = zoom(fields, (1, self.OUT_HEIGHT / src_h, self.OUT_WIDTH / src_w), order=0)
        fields = data_process(fields)

        return [fields, static_data, sample_id]

    def __len__(self):
        """Number of time steps in the selected dataset."""
        return self.dataset_len

def _load_onnx_session(onnx_path):
    """Create an ONNX Runtime session, trying CUDA first and then CPU only; None if both fail."""
    attempts = (
        (['CUDAExecutionProvider', 'CPUExecutionProvider'], "GPU"),
        (['CPUExecutionProvider'], "CPU"),
    )
    for providers, label in attempts:
        try:
            session = ort.InferenceSession(onnx_path, providers=providers)
        except Exception as e:  # ORT raises several exception types on load failure
            print(f"Failed to load ONNX model with {label}: {e}")
            continue
        print(f"ONNX model loaded with {label}!")
        # ORT may drop an unavailable provider without raising, so report what is active.
        print(f"Active ONNX Runtime providers: {session.get_providers()}")
        return session
    return None


def _init_output_zarr(store_path, channels, latitude_coord, longitude_coord):
    """Create (overwriting) a zarr store with an empty, time-appendable float32 ``data`` variable."""
    n_lat, n_lon = len(latitude_coord), len(longitude_coord)
    empty_dataset = xr.Dataset(
        {
            "data": (['time', 'channel', 'lat', 'lon'],
                     np.empty((0, len(channels), n_lat, n_lon), dtype=np.float32)),
        },
        coords={
            'time': np.array([], dtype='datetime64[ns]'),
            'channel': channels,
            'lat': latitude_coord,
            'lon': longitude_coord,
        },
    )
    # One chunk per (time, channel) full field.
    empty_dataset = empty_dataset.chunk({'time': 1, 'channel': 1, 'lat': n_lat, 'lon': n_lon})
    encoding = {
        "data": {'chunks': (1, 1, n_lat, n_lon)},  # must match the chunk dict above
        # Store time as integer hours from an early reference date.
        'time': {'units': 'hours since 1900-01-01', 'dtype': 'int64'},
    }
    empty_dataset.to_zarr(store_path, mode='w', encoding=encoding)


def _expand_daily_timestamps(sample_ids, steps_per_day):
    """Return ``steps_per_day`` evenly spaced datetime64[ns] stamps for each day spanned by *sample_ids*."""
    step = pd.Timedelta(days=1) / steps_per_day
    start = pd.to_datetime(sample_ids[0]).floor("D")
    end = pd.to_datetime(sample_ids[-1]).floor("D") + pd.Timedelta(days=1) - step
    stamps = pd.date_range(start=start, end=end, freq=step)
    return stamps.values.astype('datetime64[ns]')


def _sample_euler(session, output_names, condition, n_frames, n_channels, num_steps):
    """Integrate from Gaussian noise using explicit Euler steps on the model's predicted velocity.

    Returns a float32 array of shape ``(batch, n_frames, n_channels, H, W)``,
    where ``H, W`` are the spatial dims of *condition*.
    """
    batch, height, width = condition.shape[0], condition.shape[-2], condition.shape[-1]
    dt = 1.0 / num_steps
    x_t = torch.randn(batch, n_frames, n_channels, height, width).numpy()
    for j in range(num_steps):
        t = np.full((batch,), j * dt, dtype=np.float32)
        onnx_inputs = {'x_t': x_t, 'condition': condition, 't': t}
        v_pred = session.run(output_names, onnx_inputs)[0]
        x_t = x_t + v_pred * dt
    return x_t


def infer_onnx(args):
    """Run the ONNX model over the inference dataset and append de-normalised results to a zarr store.

    Args:
        args: Namespace with ``onnx_path`` (model file) and ``savePath``
            (output zarr store; it is created or overwritten).

    Each dataset sample is expanded into ``steps_per_day`` frames at evenly
    spaced times covering the sample's calendar day. Everything runs on
    CPU-side numpy arrays, so the ORT CPU fallback works on hosts without CUDA.
    Returns early, after printing the errors, if the model cannot be loaded.
    """
    channels = ["z250", "z500", "t2m", "u10", "v10", "msl", "tp"]
    steps_per_day = 4  # 6-hourly frames per sample day
    num_steps = 10     # Euler integration steps

    os.makedirs(args.savePath, exist_ok=True)

    val_dataset = zarrDataset(flag="infer", var_names=channels)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=0,
                                             drop_last=False)

    mean = val_dataset.mean[val_dataset.var_index]
    std = val_dataset.std[val_dataset.var_index]

    print(f"Loading ONNX model: {args.onnx_path}")
    session = _load_onnx_session(args.onnx_path)
    if session is None:
        return

    input_names = [inp.name for inp in session.get_inputs()]
    output_names = [out.name for out in session.get_outputs()]
    print(f"ONNX model inputs: {input_names}")
    print(f"ONNX model outputs: {output_names}")

    latitude_coord = np.load("./utils/LAT_COORD.npy")
    longitude_coord = np.load("./utils/LON_COORD.npy")
    n_lat, n_lon = len(latitude_coord), len(longitude_coord)

    _init_output_zarr(args.savePath, channels, latitude_coord, longitude_coord)

    for inputs, static, sample_ids in tqdm(val_loader, desc="ONNX inferring"):
        condition = torch.cat([inputs, static], dim=1).numpy()
        x_t = _sample_euler(session, output_names, condition, steps_per_day, len(channels), num_steps)

        data = x_t.reshape(-1, len(channels), n_lat, n_lon)
        # De-normalise; cast to the float32 dtype the store was created with.
        data = (data * std[None, :, None, None] + mean[None, :, None, None]).astype(np.float32)

        times = _expand_daily_timestamps(sample_ids, steps_per_day)
        if len(times) != data.shape[0]:
            raise ValueError(
                f"Expanded {len(times)} timestamps for {data.shape[0]} output frames; "
                "the samples in a batch must cover consecutive days."
            )

        sub_dataset = xr.Dataset(
            {"data": (['time', "channel", 'lat', 'lon'], data)},
            coords={
                'time': times,
                'channel': channels,
                'lat': latitude_coord,
                'lon': longitude_coord,
            },
        )
        sub_dataset.to_zarr(args.savePath, mode='a', append_dim='time')


if __name__ == '__main__':
    # Both paths are needed: without them os.makedirs / InferenceSession fail
    # late with an obscure error, so let argparse reject the call up front.
    parser = argparse.ArgumentParser(description='ONNX Inference parameters')
    parser.add_argument('--onnx_path', type=str, required=True, help='Path to ONNX model')
    parser.add_argument('--savePath', type=str, required=True, help='Path to save results')
    args = parser.parse_args()

    infer_onnx(args)
