#! /usr/bin/env python3

from osgeo import gdal, osr, gdalconst
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import os
import pyproj as pyp
import pandas as pd

def get_extent_npx(ds):
    """Return the bounding box, pixel counts and resolution of a GDAL dataset.

    The bounding box is the axis-aligned envelope of the four raster corners
    computed from the dataset's affine geotransform. That makes it valid for
    north-up (negative ``yres``), south-up (positive ``yres``) and skewed
    grids alike. For the common north-up, unskewed case it is exactly
    ``(upper-left x, lower-right y, lower-right x, upper-left y)``.

    Parameters
    ----------
    ds : gdal.Dataset
        Any object exposing ``GetGeoTransform()``, ``RasterXSize`` and
        ``RasterYSize``.

    Returns
    -------
    extent : tuple of float
        ``(xmin, ymin, xmax, ymax)`` in georeferenced units. This is the
        order ``gdal.Warp(outputBounds=...)`` expects.
    cols, rows : int
        Raster width and height in pixels.
    xres, yres : float
        Pixel-size terms of the geotransform (elements 1 and 5), with their
        signs preserved.
    """
    upx, xres, xskew, upy, yskew, yres = ds.GetGeoTransform()

    cols = ds.RasterXSize
    rows = ds.RasterYSize

    # Georeferenced coordinates of the four pixel-grid corners (col, row) in
    # {0, cols} x {0, rows}, via the affine geotransform.
    corner_x = [upx + c * xres + r * xskew for c in (0, cols) for r in (0, rows)]
    corner_y = [upy + c * yskew + r * yres for c in (0, cols) for r in (0, rows)]

    # Use min/max rather than fixed corners so the box is never inverted
    # when yres > 0 (south-up) or the grid is skewed.
    extent = (min(corner_x), min(corner_y), max(corner_x), max(corner_y))

    return extent, cols, rows, xres, yres

def get_coordinate_vector(extent, cols, rows, xres, yres):
    """Return 1-D pixel-centre coordinate vectors ``(x, y)`` for a grid.

    Each axis is spaced evenly between the extent edges, inset by half a
    pixel, so the values sit at pixel centres. Both vectors run from the
    ``extent[0]`` / ``extent[1]`` edge towards the ``extent[2]`` /
    ``extent[3]`` edge.

    Parameters
    ----------
    extent : sequence of float
        ``(xmin, ymin, xmax, ymax)``, e.g. as returned by ``get_extent_npx``.
    cols, rows : int
        Number of pixel centres along x and y.
    xres, yres : float
        Pixel sizes; only their magnitude is used.
    """
    def _centres(lower_edge, upper_edge, step, count):
        # Half a pixel in from each edge gives the first and last centre.
        inset = np.abs(step / 2)
        return np.linspace(lower_edge + inset, upper_edge - inset, count)

    x_centres = _centres(extent[0], extent[2], xres, cols)
    y_centres = _centres(extent[1], extent[3], yres, rows)
    return x_centres, y_centres

def warp_to_target_grid(sds, tds, out_name, method="Average", outputType = gdalconst.GDT_Float32):
    """Resample ``sds`` onto the grid (CRS, bounds, size) of ``tds``.

    Parameters
    ----------
    sds : gdal.Dataset
        Source raster to be resampled.
    tds : gdal.Dataset
        Raster whose projection, extent and pixel dimensions define the
        output grid.
    out_name : str
        Destination path handed to ``gdal.Warp``.
    method : str
        Resampling algorithm passed as ``resampleAlg`` (default "Average").
    outputType : int
        GDAL data type of the output band(s).

    Returns
    -------
    gdal.Dataset
        The dataset returned by ``gdal.Warp``. This is ``None`` on failure
        unless GDAL exceptions are enabled.
    """
    # Build spatial references from the WKT stored on each dataset.
    src_srs = osr.SpatialReference()
    src_srs.ImportFromWkt(sds.GetProjection())

    dst_srs = osr.SpatialReference()
    dst_srs.ImportFromWkt(tds.GetProjection())

    # Only the bounding box and the pixel counts of the target are needed.
    bounds, width, height = get_extent_npx(tds)[:3]

    return gdal.Warp(
        out_name,
        sds,
        srcSRS=src_srs,
        dstSRS=dst_srs,
        outputBounds=bounds,
        width=width,
        height=height,
        resampleAlg=method,
        outputType=outputType,
    )
    
    
    

# source
s_file = '/path/to/metstor_nfs/home/david/Projects/SECURES/Data_Power/GIS processing files von Ricki/usable-area+buffer301m.tif'
sds = gdal.Open(s_file, gdal.GA_ReadOnly)
if sds is None:
    raise RuntimeError(f"Could not open source raster: {s_file}")
band = sds.GetRasterBand(1)
# Declare pixel value 7 as nodata on the source band, presumably so gdal.Warp
# ignores those pixels when resampling.
# NOTE(review): the dataset is opened read-only, so GDAL may persist this in a
# .aux.xml sidecar next to the source file; confirm that is acceptable.
band.SetNoDataValue(7)

# target grid
t_file = '/path/to/wind/wind-2019_01-08-150mPRE-avg.grb'
tds = gdal.Open(t_file, gdal.GA_ReadOnly)
if tds is None:
    raise RuntimeError(f"Could not open target grid: {t_file}")


# output
# splitext swaps only the final extension, unlike str.replace, which would
# rewrite every ".tif" occurrence in the name.
out_file = os.path.join(
    "/path/to/wind_masks_projected/",
    os.path.splitext(os.path.basename(s_file))[0] + "_COSMO6.tif",
)
pds = warp_to_target_grid(sds, tds, out_file)
if pds is None:
    raise RuntimeError(f"gdal.Warp failed to create {out_file}")


## project target grid to lat lon

p_srs = osr.SpatialReference()
p_srs.ImportFromWkt(pds.GetProjection())

p_crs = pyp.CRS.from_proj4(p_srs.ExportToProj4())
wgs84 = "epsg:4326"
# No always_xy: EPSG:4326 uses its authority axis order, so transform()
# yields (lat, lon), hence the LAT, LON unpacking below.
transformer = pyp.Transformer.from_crs(p_crs, wgs84)

x, y = get_coordinate_vector(*get_extent_npx(pds))
# Rows of X/Y (and therefore LAT/LON) follow ascending y.
X, Y = np.meshgrid(x, y)

LAT, LON = transformer.transform(X, Y)

# The CSV is written in the meshgrid (ascending-y) order, not raster row order.
df = pd.DataFrame(data={"LON": LON.ravel(), "LAT": LAT.ravel()})
df.to_csv("/path/to/wind_mask_land_LAT_LON.csv")

# set some values to 0: pixels whose longitude lies outside this range
# NOTE(review): presumably the longitude span of the domain of interest; confirm.
LON_MIN, LON_MAX = -31, 62.7
p_band = pds.GetRasterBand(1)
pds_data = p_band.ReadAsArray()
# Flip LON vertically to match raster row order (row 0 = top edge).
# NOTE(review): this assumes a north-up output grid (negative y resolution).
lon_raster_order = LON[::-1, :]
pds_data[(lon_raster_order < LON_MIN) | (lon_raster_order > LON_MAX)] = 0
p_band.WriteArray(pds_data)
pds.FlushCache()

# GDAL only guarantees a dataset is completely written/closed once it is
# dereferenced. Release band handles before their owning datasets.
p_band = None
pds = None
band = None
sds = None
tds = None


