#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import xarray as xr
import glob
import numpy as np

def estimate_power_per_discharge(df_in, pp_data, n_iter=3, scaling_df = None):
    """Iteratively scale per-plant time series to power and cap at capacity.

    Each iteration multiplies every column by a per-plant factor (target mean
    power / current column mean) and then caps the result at the plant's
    installed capacity. Because capping lowers the mean again, repeating the
    step moves the mean of the capped series towards the target.

    Parameters
    ----------
    df_in : pandas.DataFrame
        One column per plant; the column index must have a level named
        ``"pp_id"`` whose labels exist in ``pp_data.index``.
    pp_data : pandas.DataFrame
        Plant table indexed by plant id with the columns
        ``"avg_annual_generation_GWh"`` and ``"installed_capacity_MW"``.
    n_iter : int, default 3
        Number of scale-and-cap iterations.
    scaling_df : pandas.DataFrame, optional
        Precomputed factors (index: plant id, columns ``0 .. n_iter-1``), e.g.
        the second return value of an earlier call. When given, the factors
        are applied as-is instead of being recomputed; it must contain at
        least ``n_iter`` columns (a missing column raises ``KeyError``), so
        pass the same ``n_iter`` that produced it.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The scaled and capped series (same shape as ``df_in``; missing values
        stay NaN) and the scaling factor used in each iteration.
    """
    pp_id = df_in.columns.get_level_values("pp_id")
    # GWh per year -> mean MW (8760 h per year, 1000 MW per GW)
    avg_power_MW = pp_data.loc[pp_id, "avg_annual_generation_GWh"] / 8760 * 1000
    max_capacity = pp_data.loc[pp_id, "installed_capacity_MW"]

    # copy so that n_iter == 0 never hands back the caller's own object
    ts_tmp = df_in.copy(deep=True)

    calc_scaling = scaling_df is None
    if calc_scaling:
        # explicit float dtype: a frame built from index/columns alone is object
        scaling_df = pd.DataFrame(index=pp_id, columns=range(n_iter), dtype=float)

    for ii in range(n_iter):
        if calc_scaling:
            scaling_df.loc[:, ii] = avg_power_MW / ts_tmp.mean()
        scaling = scaling_df.loc[:, ii]

        ts_tmp = ts_tmp * scaling
        # clip keeps NaN as NaN; where(ts <= cap, cap) would turn every
        # missing value into the full installed capacity
        ts_tmp = ts_tmp.clip(upper=max_capacity, axis=1)

    return ts_tmp, scaling_df
    
def analytic_powercurve(discharge, power):
    """Fit a linear power curve ``P = k * Q + d`` to the interior samples.

    Samples where ``power`` equals its minimum or maximum (e.g. the plateau
    at the capacity cap) are excluded, as are samples with a missing
    discharge or missing power; a first-degree polynomial is fitted to the
    rest by least squares.

    Parameters
    ----------
    discharge : pandas.Series
        Discharge values, index-aligned with ``power``.
    power : pandas.Series
        Power values for the same time steps.

    Returns
    -------
    tuple
        ``(k, d, P_min, P_max)``: slope, intercept, and the minimum and
        maximum of ``power`` (usable as clip bounds when evaluating the curve).
    """
    P_max = power.max()
    P_min = power.min()

    # strictly inside (P_min, P_max); NaN power compares False and drops out,
    # NaN discharge is removed explicitly because it would break polyfit
    interior = (power > P_min) & (power < P_max) & discharge.notna()

    # least squares does not depend on sample order, so no sorting is needed
    k, d = np.polyfit(discharge[interior].values, power[interior].values, 1)

    return k, d, P_min, P_max

def apply_analytic_powercurve(Q, k, d, P_min, P_max):
    """Evaluate a bounded linear power curve for the discharge values ``Q``.

    Computes ``Q * k + d``, raises everything below ``P_min`` to ``P_min``
    and then lowers everything above ``P_max`` to ``P_max``. NaN results
    compare False in both tests and therefore stay NaN.

    Parameters
    ----------
    Q : pandas.Series or numpy.ndarray
        Discharge values.
    k, d : float
        Slope and intercept of the curve.
    P_min, P_max : float
        Lower and upper bound applied to the result.

    Returns
    -------
    The bounded power values, of the same container type as ``Q * k + d``.
    """
    power = Q * k + d

    # lower bound first; the upper-bound mask is evaluated afterwards
    below = power < P_min
    power[below] = P_min

    above = power > P_max
    power[above] = P_max

    return power
    
    
    

# --- power plants -------------------------------------------------------------
# JRC hydro power plant database (incl. estimated annual generation) as a point
# layer in EPSG:4326, built from its lon/lat columns.
pp_db = "/path/to/Data_Power/jrc-hydro-power-plant-database_geometry_est_ann_power.xlsx"
pp_df = pd.read_excel(pp_db)
pp_gdf = gpd.GeoDataFrame(
    pp_df, geometry=gpd.points_from_xy(pp_df.lon, pp_df.lat)
).set_crs(epsg=4326)


# --- filter power plants ------------------------------------------------------
# Keep type 'HROR' plants that report both annual generation and capacity ...
pp_gdf_filt = pp_gdf.loc[
    (pp_gdf["type"] == 'HROR')
    & pp_gdf["avg_annual_generation_GWh"].notna()
    & pp_gdf["installed_capacity_MW"].notna()
]
# ... whose mean power (GWh per year / 8.76 = MW) does not exceed the installed
# capacity, and index the result by the plant "id" column.
pp_gdf_filt = pp_gdf_filt.loc[
    pp_gdf_filt["avg_annual_generation_GWh"] / 8.76
    <= pp_gdf_filt["installed_capacity_MW"]
].set_index("id")


# Historical discharge (eHYPE model mean), restricted to the selected plants.
Q_all = pd.read_pickle("/path/to/Data_Ehype/eHYPE_discharge_model-mean_historical.pkl")
Q_all = Q_all.loc[:,pp_gdf_filt.index]
# keep only the first column level (the one selected by plant id above)
Q_all.columns = Q_all.columns.get_level_values(0)

# Historical power series: discharge scaled towards each plant's mean generation
# and capped at its installed capacity.
daily_ror0, scaling = estimate_power_per_discharge(Q_all, pp_gdf_filt, n_iter = 10)

# Per-plant linear power curve (k, d, P_min, P_max) fitted to the historical
# series. Built in a single constructor so the columns get a numeric dtype; a
# frame created from index/columns only and filled row by row stays object.
df_apc = pd.DataFrame(
    [analytic_powercurve(Q_all[pp], daily_ror0[pp]) for pp in daily_ror0.columns],
    index=daily_ror0.columns,
    columns=['k', 'd', 'P_min', 'P_max'],
)

    
# Scenario discharge (RCP4.5 and RCP8.5 model means) side by side, restricted to
# the selected plants; column level 1 is dropped afterwards.
Q_all_rcp45 = pd.read_pickle("/path/to/Data_Ehype/eHYPE_discharge_model-mean_rcp45.pkl")
Q_all_rcp85 = pd.read_pickle("/path/to/Data_Ehype/eHYPE_discharge_model-mean_rcp85.pkl")
Q_all_scen = pd.concat([Q_all_rcp45, Q_all_rcp85], axis=1).loc[:, pp_gdf_filt.index]
Q_all_scen.columns = Q_all_scen.columns.droplevel(1)


def _scenario_powercurve(Q):
    """Convert one scenario discharge column with its plant's fitted curve.

    The plant id is the first element of the column label ``Q.name``.
    """
    k, d, P_min, P_max = df_apc.loc[Q.name[0], :].values
    return apply_analytic_powercurve(Q, k, d, P_min, P_max)


daily_ror_scen = Q_all_scen.apply(_scenario_powercurve)

#%%
# Annual mean of the summed (all-plant) power: historical series and one
# column per scenario.
# NOTE(review): the resample alias "A" is deprecated since pandas 2.2 in favour
# of "YE"; switch once the minimum supported pandas version allows it.
ann_mean_total_power = daily_ror0.resample("A").mean().sum(axis=1)
ann_mean_total_power.name="historical"
# Sum the plants of each scenario (column level 1). Transposing and grouping the
# rows is equivalent to the deprecated groupby(axis=1, level=1).
ann_mean_total_power_scen = daily_ror_scen.T.groupby(level=1).sum().T.resample("A").mean()

ann_mean_total_power_hist_scen = ann_mean_total_power_scen.join(ann_mean_total_power.astype(np.float32), how="outer")
# Years missing from a scenario column (after the outer join) take the
# historical value. Assign the filled column back: fillna(inplace=True) on
# df["col"] is chained assignment and does not update the frame under pandas
# Copy-on-Write.
for scen in ("rcp85", "rcp45"):
    ann_mean_total_power_hist_scen[scen] = ann_mean_total_power_hist_scen[scen].fillna(
        ann_mean_total_power_hist_scen["historical"]
    )
ann_mean_rel_power_hist_scen = ann_mean_total_power_hist_scen/ann_mean_total_power.mean()


# Totals relative to the long-term historical mean.
ann_mean_rel_power = ann_mean_total_power/ann_mean_total_power.mean()
ann_mean_rel_power_scen = ann_mean_total_power_scen/ann_mean_total_power.mean()

