#! /usr/bin/env python3
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mpl
import glob
import geopandas as gpd

# ERA5 grid cells; below they are tagged with the value of the land mask and
# of the offshore mask polygon(s) they intersect.
grid = gpd.read_file("/path/to/ERA5_SECURES_Grid.shp")

# Each mask shapefile: drop its "cat" column, then name the remaining value
# column after the mask. The two-name assignment requires exactly one column
# besides geometry (NOTE(review): assumes geometry is the last column).
land = gpd.read_file("/path/to//wind_land_mask_era5_vector.shp").drop(columns="cat")
land.columns = ["land", "geometry"]
offshore = gpd.read_file("/path/to/wind_offshore_mask_era5_vector.shp").drop(columns="cat")
offshore.columns = ["offshore", "geometry"]

# Default sjoin is an inner "intersects" join, so only grid cells touching
# both masks survive; the helper "index_right" column is discarded each time
# so the next join does not clash with it.
for mask_gdf in (land, offshore):
    grid = gpd.sjoin(grid, mask_gdf).drop(columns="index_right")


# Column order of the NUTS2/NUTS3 weight tables written to CSV.
WEIGHT_COLUMNS = ["pop_density_sum", "n_pixel", "land", "offshore"]


def _region_weights(nc_paths, grid_gdf, keep_code):
    """Aggregate population and grid land/offshore weights per region file.

    Parameters
    ----------
    nc_paths : iterable of str
        Paths to ``<code>_isopop_era5.nc`` files; the region code is the file
        name up to its first ``_``. Each path is printed as it is visited.
    grid_gdf : geopandas.GeoDataFrame
        Grid cells carrying ``land`` and ``offshore`` columns.
    keep_code : callable
        Predicate on the region code; files whose code fails it are skipped.

    Returns
    -------
    pandas.DataFrame
        One row per kept region code (in *nc_paths* order) with
        ``WEIGHT_COLUMNS``: the sum of ``popden``, the sum of ``fmask``, and
        the sums of ``land`` / ``offshore`` over the grid cells that
        intersect the non-NaN ``fmask`` points.
    """
    rows = {}
    for path in nc_paths:
        print(path)
        code = path.split("/")[-1].split("_")[0]
        if not keep_code(code):
            continue
        with xr.open_dataset(path) as pop_ds:
            pop_sum = pop_ds.popden.sum().item()
            # NOTE(review): this is a pixel count only if fmask is 1/NaN — confirm.
            n_pixel = pop_ds.fmask.sum().item()
            # Coordinates of every non-NaN mask cell, materialised before the
            # dataset is closed.
            cells = pop_ds.fmask.to_dataframe().dropna().reset_index()
        points = gpd.GeoDataFrame(
            geometry=gpd.points_from_xy(cells.lon, cells.lat)
        ).set_crs(epsg=4326)
        joined = gpd.sjoin(grid_gdf, points, how="inner")
        # A repeated code overwrites its earlier row in place, as the former
        # .loc assignment did.
        rows[code] = [
            pop_sum,
            n_pixel,
            joined["land"].sum().item(),
            joined["offshore"].sum().item(),
        ]
    # Build the table once rather than enlarging it row by row with .loc,
    # which copies the frame per insert and leaves object-dtype columns.
    return pd.DataFrame.from_dict(rows, orient="index", columns=WEIGHT_COLUMNS)


# NUTS2
# NOTE(review): `nuts2` is not referenced elsewhere in this script as shown.
nuts2 = gpd.read_file("/path/to/NUTS2_regions.shp")

pop_weights_path = '/path/to/popweight/*_isopop_era5.nc'
pop_weights = sorted(glob.glob(pop_weights_path))
# NUTS2 pass: codes longer than 4 characters are skipped.
pop_weights_df = _region_weights(pop_weights, grid, lambda code: len(code) <= 4)
pop_weights_df.to_csv("/path/to/output/NUTS2_weights.csv")


# NUTS3
pop_weights_path = '/path/to/NUTS2_era5/Total_pop_era5/AT???_isopop_era5.nc'
pop_weights = sorted(glob.glob(pop_weights_path))
# NUTS3 pass: codes shorter than 5 characters are skipped.
pop_weights_df = _region_weights(pop_weights, grid, lambda code: len(code) >= 5)
# NOTE(review): unlike the NUTS2/EEZ outputs this is not under /path/to/output/ — confirm.
pop_weights_df.to_csv("/path/to/NUTS3_weights.csv")


# EEZ
eez_mask_path = "/path/to/Offshore/MRGID_*_at_era5.nc"
eez_masks = sorted(glob.glob(eez_mask_path))

# All output columns are declared explicitly instead of relying on .loc to
# create missing columns on assignment (version-dependent pandas behaviour).
# The order matches what that implicit creation wrote to the CSV.
EEZ_COLUMNS = ["offshore", "n_pixel", "land"]

eez_rows = {}
for mask_path in eez_masks:
    print(mask_path)
    # File names follow MRGID_<id>_at_era5.nc; the second field is the id.
    mrgid = mask_path.split("/")[-1].split("_")[1]
    with xr.open_dataset(mask_path) as mask_ds:
        # NOTE(review): the variable is spelled "fmasK" here but "fmask" in the
        # population files — confirm against the EEZ NetCDF files.
        # Mask cells are 0 and non-mask cells NaN, so +1 turns every mask
        # cell into 1 and the NaN-skipping sum counts them.
        n_pixel = (mask_ds.fmasK + 1).sum().item()
        # Coordinates of every non-NaN mask cell, materialised before closing.
        cells = mask_ds.fmasK.to_dataframe().dropna().reset_index()
    points = gpd.GeoDataFrame(
        geometry=gpd.points_from_xy(cells.lon, cells.lat)
    ).set_crs(epsg=4326)
    joined = gpd.sjoin(grid, points, how="inner")
    eez_rows[mrgid] = [
        joined["offshore"].sum().item(),
        n_pixel,
        joined["land"].sum().item(),
    ]

# Build the table once rather than enlarging it row by row with .loc.
eez_df = pd.DataFrame.from_dict(eez_rows, orient="index", columns=EEZ_COLUMNS)
eez_df.to_csv("/path/to/output/EEZ_weights.csv")



