#! /usr/bin/env python3

import numpy as np
import xarray as xr
import glob
import pyproj as pyp
import xesmf as xe
from osgeo import gdal,ogr,osr
import os
from joblib import Parallel, delayed

def xy2llbounds(ncf):
    """
    Add cell-corner coordinates and geographic lat/lon to a rotated-pole grid.

    The corners of the grid boxes (``x_b``/``y_b`` and their geographic
    counterparts ``lon_b``/``lat_b``) are required by some regridding methods.

    Parameters
    ----------
    ncf : xarray.Dataset
        Dataset with 1-D coordinates ``x`` and ``y`` holding the cell centres
        in the rotated-pole system described by the PROJ string below
        (presumably degrees of rotated longitude/latitude -- confirm against
        the input data). The spacing along each axis is assumed to be
        uniform; the mean spacing is used.

    Returns
    -------
    ncf2 : xarray.Dataset
        Copy of ``ncf`` with the added coordinates ``x_b``, ``y_b`` (1-D cell
        edges, one element longer than ``x``/``y``) and ``lon``, ``lat``,
        ``lon_b``, ``lat_b`` (2-D, EPSG:4326).

    """
    x = ncf.x.values
    y = ncf.y.values

    # Mean signed grid spacing; negative for a descending axis.
    dx = ncf.x.diff(dim="x").mean().item()
    dy = ncf.y.diff(dim="y").mean().item()

    # The first edge lies half a *signed* step before the first centre, so
    # that centre i always sits between edge i and edge i+1. (Using the
    # absolute step here would shift all edges by one cell on descending axes.)
    x_b = x[0] - dx / 2 + np.arange(len(x) + 1) * dx
    y_b = y[0] - dy / 2 + np.arange(len(y) + 1) * dy

    # transform from the rotated-pole grid (spherical earth) to WGS84
    rea_proj4 = "+proj=ob_tran +o_proj=longlat +o_lat_p=39.25 +o_lon_p=0 +lon_0=18 +no_defs +R=6371229"
    crs_rea = pyp.CRS.from_proj4(rea_proj4)
    crs_wgs84 = pyp.CRS.from_epsg(4326)

    # always_xy=True pins the axis order to (x/lon, y/lat) on both sides, so
    # the unpacking below does not depend on EPSG:4326's official lat/lon order.
    transformer = pyp.Transformer.from_crs(crs_rea, crs_wgs84, always_xy=True)

    X, Y = np.meshgrid(x, y)
    X_B, Y_B = np.meshgrid(x_b, y_b)

    lon, lat = transformer.transform(X, Y)
    lon_b, lat_b = transformer.transform(X_B, Y_B)

    # attach edges and geographic coordinates
    ncf2 = ncf.assign_coords(
        y_b=("y_b", y_b),
        x_b=("x_b", x_b),
        lon=(["y", "x"], lon),
        lat=(["y", "x"], lat),
        lon_b=(["y_b", "x_b"], lon_b),
        lat_b=(["y_b", "x_b"], lat_b),
    )

    return ncf2


def llbounds(ncf):
    """
    Add 1-D cell-edge coordinates to a regular lat/lon grid.

    Parameters
    ----------
    ncf : xarray.Dataset
        Dataset with 1-D coordinates ``lon`` and ``lat`` holding the cell
        centres. The spacing along each axis is assumed to be uniform; the
        mean spacing is used. Either axis may be ascending or descending.

    Returns
    -------
    ncf2 : xarray.Dataset
        Copy of ``ncf`` with the added coordinates ``lon_b`` and ``lat_b``,
        each one element longer than the corresponding centre coordinate.

    """
    lon = ncf.lon.values
    lat = ncf.lat.values

    # Mean signed grid spacing; negative for a descending axis
    # (e.g. latitudes stored north-to-south).
    dlon = ncf.lon.diff(dim="lon").mean().item()
    dlat = ncf.lat.diff(dim="lat").mean().item()

    # The first edge lies half a *signed* step before the first centre, so
    # that centre i always sits between edge i and edge i+1. (Using the
    # absolute step here would shift all edges by one cell on descending axes.)
    lon_b = lon[0] - dlon / 2 + np.arange(len(lon) + 1) * dlon
    lat_b = lat[0] - dlat / 2 + np.arange(len(lat) + 1) * dlat

    ncf2 = ncf.assign_coords(
        lon_b=(["lon_b"], lon_b),
        lat_b=(["lat_b"], lat_b),
    )

    return ncf2

def regrid_spatial(in_file, targetgrid_file, weights_file, out_name):
    """
    Regrid one file to the target grid with xESMF's 'patch' method.

    Parameters
    ----------
    in_file : str
        path to the netCDF file to be interpolated. It must provide the
        dimensions and coordinate variables ``rlon`` and ``rlat``.
    targetgrid_file : str
        path to target grid.
    weights_file : str
        path to the regridding weights. An existing file is reused;
        otherwise the weights are computed and saved to this path.
    out_name : str
        path the regridded dataset is written to.

    Returns
    -------
    int
        Always 0; the regridded data is written to ``out_name``.

    """
    print("Reading", in_file)

    # Context managers close the underlying files even if regridding fails.
    # Closing only the derived datasets (after rename/assign_coords) is not
    # relied upon to release the original file handles.
    with xr.open_dataset(targetgrid_file) as grid_file, \
            xr.open_dataset(in_file) as source_file:

        # read out grid
        target_grid = grid_file
        if "lon_b" not in target_grid.coords:
            target_grid = llbounds(target_grid)

        # add bounds and lat/lon to file
        ds_in = source_file.rename_dims({"rlon": "x", "rlat": "y"})
        ds_in = ds_in.rename_vars({"rlon": "x", "rlat": "y"})
        ds_in = xy2llbounds(ds_in)

        # read/calculate regridder
        if os.path.isfile(weights_file):
            print("Reusing weights", weights_file)
            regridder = xe.Regridder(ds_in, target_grid, 'patch',
                                     weights=weights_file, reuse_weights=True)
        else:
            print("Calculating Weights")
            regridder = xe.Regridder(ds_in, target_grid, 'patch',
                                     filename=weights_file)
            # Recent xESMF versions keep freshly computed weights in memory
            # only; write them out so later calls can take the reuse branch.
            # The check skips this where the file was already written.
            if not os.path.isfile(weights_file):
                regridder.to_netcdf(weights_file)

        # regrid and write while the input file is still open
        print("Regridding", in_file)
        ds_out = regridder(ds_in, keep_attrs=True)
        try:
            ds_out.to_netcdf(out_name)
        finally:
            ds_out.close()

    return 0

if __name__ == "__main__":
    # Script entry point: regrid every matching input file onto the target
    # grid, writing one "<name>_atERA5.nc4" file per input into out_path.
    in_files = sorted(glob.glob("/path/to/REA/converted/hourly/2D/WS_150/WS_150m.2D.*.nc4"))
    target_grid_file = '/path/to//metstor_nfs/projects/Secrues/Output/Wspd/ERA5_ERA5L_HourlMerged/OUTPUT/ERA5_Land_1981_Wspeed_hourly.nc'
    out_path = '/path/to//metstor_nfs/projects/Secrues/REA/converted/regridded_ERA5/'

    num_cores = 24

    # Fail with a clear message instead of an IndexError on in_files[0].
    if not in_files:
        raise SystemExit("No input files found; nothing to regrid.")

    weights_file = out_path + "COSMO-REA6_2_ERA5.weights"

    # One argument tuple per input file, in the order regrid_spatial expects.
    jobs = [
        (
            in_file,
            target_grid_file,
            weights_file,
            out_path + os.path.basename(in_file).replace(".nc4", "_atERA5.nc4"),
        )
        for in_file in in_files
    ]

    # run first file serially to avoid a race condition on the weights file
    regrid_spatial(*jobs[0])

    # run the rest in parallel
    Parallel(n_jobs=num_cores)(delayed(regrid_spatial)(*args) for args in jobs[1:])







