#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import xarray as xr
import glob
import numpy as np

def estimate_power_per_discharge(df_in, pp_data, n_iter=3, scaling_df=None):
    """Scale discharge series to power and cap them at the installed capacity.

    In each iteration every column is multiplied by a per-plant factor and
    then limited to the plant's installed capacity. When no factors are
    supplied, the factor of an iteration is the plant's mean power (annual
    generation in GWh converted to MW via 8760 h per year) divided by the
    current column mean, so repeated iterations compensate for the energy
    removed by capping.

    Parameters
    ----------
    df_in : pandas.DataFrame
        Time series with one column per plant. The columns must expose a
        level (or index name) called ``"pp_id"``. The frame is not modified.
    pp_data : pandas.DataFrame
        Plant table indexed by plant id with the columns
        ``"avg_annual_generation_GWh"`` and ``"installed_capacity_MW"``.
    n_iter : int, optional
        Number of scale-and-cap iterations.
    scaling_df : pandas.DataFrame, optional
        Previously computed factors (index: plant id, columns:
        ``0 .. n_iter-1``). When given, these factors are applied as they are
        and nothing is recomputed.

    Returns
    -------
    ts_tmp : pandas.DataFrame
        Scaled and capped series.
    scaling_df : pandas.DataFrame
        Factors used in each iteration (the object passed in, if any).
    """
    pp_id = df_in.columns.get_level_values("pp_id")
    avg_power_MW = pp_data.loc[pp_id, "avg_annual_generation_GWh"] / 8760 * 1000
    max_capacity = pp_data.loc[pp_id, "installed_capacity_MW"]

    ts_tmp = df_in.copy(deep=True)

    calc_scaling = scaling_df is None
    if calc_scaling:
        # Float dtype up front avoids an object-dtype table of factors.
        scaling_df = pd.DataFrame(index=pp_id, columns=range(n_iter), dtype=float)

    for ii in range(n_iter):
        if calc_scaling:
            scaling_df.loc[:, ii] = avg_power_MW / ts_tmp.mean()

        scaling = scaling_df.loc[:, ii]

        ts_tmp = ts_tmp * scaling
        # Missing time steps must stay missing: `NaN <= capacity` is False, so
        # without the isna() term gaps would be replaced by the full capacity
        # and bias the column mean used in the next iteration.
        keep = (ts_tmp <= max_capacity) | ts_tmp.isna()
        ts_tmp = ts_tmp.where(keep, max_capacity, axis=1)

    return ts_tmp, scaling_df
    
def analytic_powercurve(discharge, power):
    """Fit a straight line ``P = k*Q + d`` to the interior of a power curve.

    Only samples whose power lies strictly between the smallest and the
    largest observed power enter the fit; samples sitting exactly on either
    extreme are left out.

    Parameters
    ----------
    discharge, power : pandas.Series
        Matching series of discharge and power (same index).

    Returns
    -------
    tuple
        ``(k, d, P_min, P_max)``: slope and intercept of the fitted line and
        the minimum and maximum of ``power``.
    """
    P_min = power.min()
    P_max = power.max()

    # One shared mask keeps discharge and power samples paired.
    interior = (power > P_min) & (power < P_max)
    P = power[interior].values
    Q = discharge[interior].values

    # Feed the points to the fit ordered by increasing power.
    order = np.argsort(P)
    slope, intercept = np.polyfit(Q[order], P[order], 1)

    return slope, intercept, P_min, P_max

def apply_analytic_powercurve(Q, k, d, P_min, P_max):
    """Evaluate the linear power curve ``P = Q*k + d`` limited to [P_min, P_max].

    Parameters
    ----------
    Q : scalar, numpy.ndarray, pandas.Series or pandas.DataFrame
        Discharge value(s). The input itself is not modified.
    k, d : float
        Slope and intercept of the power curve.
    P_min, P_max : float
        Lower and upper limit applied to the result (lower limit first).

    Returns
    -------
    Same kind of object as ``Q*k + d`` with all values limited to the range.
    """
    P_out = Q * k + d

    # Scalars do not support boolean-mask assignment, so limit them directly.
    if np.ndim(P_out) == 0:
        return min(max(P_out, P_min), P_max)

    P_out[P_out < P_min] = P_min
    P_out[P_out > P_max] = P_max

    return P_out
    
    
    

# powerplants
# Read the plant table and turn it into a GeoDataFrame of lon/lat points (EPSG:4326).
pp_db = "/path/to/Data_Power/jrc-hydro-power-plant-database_geometry_est_ann_power.xlsx"
pp_df = pd.read_excel(pp_db)
pp_points = gpd.points_from_xy(pp_df.lon, pp_df.lat)
pp_gdf = gpd.GeoDataFrame(pp_df, geometry=pp_points)
pp_gdf = pp_gdf.set_crs(epsg=4326)

# filter Powerplants
# Keep plants whose type is not 'HROR' and that report both generation and capacity.
not_hror = pp_gdf["type"] != 'HROR'
has_generation = pp_gdf["avg_annual_generation_GWh"].notna()
has_capacity = pp_gdf["installed_capacity_MW"].notna()
pp_gdf_filt = pp_gdf[not_hror & has_generation & has_capacity]

# Drop plants whose mean power (annual GWh / 8.76 -> MW) exceeds the installed capacity.
mean_power_MW = pp_gdf_filt["avg_annual_generation_GWh"] / 8.76
pp_gdf_filt = pp_gdf_filt[mean_power_MW <= pp_gdf_filt["installed_capacity_MW"]]
pp_gdf_filt = pp_gdf_filt.set_index("id")


# Historical discharge; level 0 of its columns is matched against the plant ids below.
Q_all = pd.read_pickle("/path/to/Data_Ehype/eHYPE_discharge_model-mean_historical.pkl")

# Ids present both in the filtered plant table and in the discharge data.
discharge_ids = set(Q_all.columns.get_level_values(0))
cmt_avail = sorted(discharge_ids.intersection(pp_gdf_filt.index))
Q_all = Q_all.loc[:, cmt_avail]
Q_all.columns = Q_all.columns.get_level_values(0)

# Per-plant factor: mean power (annual GWh -> MW) divided by mean historical discharge.
avg_power_MW = pp_gdf_filt["avg_annual_generation_GWh"] / 8760 * 1000
scaling = (avg_power_MW / Q_all.mean()).dropna()
#%%
daily_power_reservoir = Q_all * scaling

# Scenario discharge, side by side, reduced to the same ids.
Q_all_rcp45 = pd.read_pickle("/path/to/Data_Ehype//eHYPE_discharge_model-mean_rcp45.pkl")
Q_all_rcp85 = pd.read_pickle("/path/to/Data_Ehype//eHYPE_discharge_model-mean_rcp85.pkl")
Q_all_scen = pd.concat([Q_all_rcp45, Q_all_rcp85], axis=1).loc[:, cmt_avail]
Q_all_scen.columns = Q_all_scen.columns.droplevel(1)

# Move column level 1 into the rows so the per-plant factors align on the
# remaining column level, then restore the wide layout.
scaled_long = Q_all_scen.stack(1) * scaling
daily_reservoir_scen = scaled_long.unstack(1)

#%%
# Annual mean power summed over all plants for the historical run.
# NOTE(review): the "A" resample alias is deprecated in recent pandas in favour
# of "YE"; it is kept here because "YE" is not accepted by older versions.
ann_mean_total_power = daily_power_reservoir.resample("A").mean().sum(axis=1)
ann_mean_total_power.name = "historical"

# Sum over plants per scenario (column level 1), then average per year.
# Grouping on the transposed frame replaces the deprecated groupby(axis=1).
ann_mean_total_power_scen = daily_reservoir_scen.T.groupby(level=1).sum().T.resample("A").mean()

ann_mean_total_power_hist_scen = ann_mean_total_power_scen.join(ann_mean_total_power.astype(np.float32), how="outer")
# Fill the scenario columns with the historical values where they have none.
# The result is assigned back explicitly: a chained `df[col].fillna(...,
# inplace=True)` does not update the frame under pandas Copy-on-Write.
for scen in ("rcp85", "rcp45"):
    ann_mean_total_power_hist_scen[scen] = ann_mean_total_power_hist_scen[scen].fillna(
        ann_mean_total_power_hist_scen["historical"]
    )
ann_mean_rel_power_hist_scen = ann_mean_total_power_hist_scen/ann_mean_total_power.mean()


# Everything is expressed relative to the mean historical annual value.
ann_mean_rel_power = ann_mean_total_power/ann_mean_total_power.mean()
ann_mean_rel_power_scen = ann_mean_total_power_scen/ann_mean_total_power.mean()

#%%
# Plot relative annual production: yearly values plus a 21-row rolling mean,
# historical in the first colour of the cycle, the scenarios in the next two.

prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']

scen_cols = ["rcp45", "rcp85"]
scen_period = slice("2005-12", "2100-12-31")
# Computed once and reused for both smoothed curves.
rel_power_smooth = ann_mean_rel_power_hist_scen.rolling(21).mean()

ax = ann_mean_rel_power.plot(color=colors[0], label="historical")
ann_mean_rel_power_hist_scen.loc[scen_period, scen_cols].plot(color=colors[1:3], ax=ax)
rel_power_smooth.loc[:, ["historical"]].plot(color=colors[0], ax=ax, label='_nolegend_')
rel_power_smooth.loc[scen_period, scen_cols].plot(color=colors[1:3], ax=ax, label='_nolegend_')


plt.xlabel("year")
plt.ylabel("relative annual energy production [1]")

plt.ylim([0.8,1.25])

plt.title("Relative annual energy production from reservoir plants")

plt.grid()

# Only the first three artists (the yearly curves) get a legend entry.
handles, labels = ax.get_legend_handles_labels()
plt.legend(handles[:3], labels[:3])
plt.savefig("/path/to/figures/Relative_energy_production_reservoir_all.png")
plt.show()

#%%
# Give the historical frame a second column level ("historical") so its
# columns have the same two-level layout as the scenario frame.
hist_columns = pd.MultiIndex.from_product([daily_power_reservoir.columns, ["historical"]])
daily_power_reservoir.columns = hist_columns

# Outer join on the time index, then persist the combined series and the factors.
daily_reservoir_all = daily_power_reservoir.join(daily_reservoir_scen, how="outer")
daily_reservoir_all.to_pickle("/path/to/Data_Ehype//daily_power_reservoir_all.pkl")
scaling.to_pickle("/path/to/Data_Ehype/scaling_reservoir_all.pkl")

#%%
# Every plant id occurs once per scenario in column level 0; deduplicate so
# that metadata rows and exported discharge columns are not repeated.
plants_out = daily_reservoir_all.columns.get_level_values(0).unique()

# Plant metadata plus the discharge-to-power factor 'k'.
df_meta = pp_gdf_filt.loc[plants_out,:].drop('Unnamed: 0', axis=1)
df_meta['k'] = scaling.loc[plants_out]




# Export data
# TODO check variable names
df_meta.to_file("/path/to/Data_deliver/Data_Hydro/Metadata_reservoir.shp")
df_meta.to_csv("/path/to/Data_deliver/Data_Hydro/Metadata_reservoir.csv", sep=";")

# Discharge per plant. The historical file name used to read "2071-2005";
# it now carries the same period as the historical power export below.
Q_all.loc[:,plants_out].to_csv("/path/to/Data_deliver/Data_Hydro/EHype_ensemble_daily_mean_historic_1971-2005_reservoir.csv", float_format = "%5.1f", sep=";")
Q_all_scen.loc[:, plants_out].xs("rcp45", axis=1, level=1).to_csv("/path/to/Data_deliver/Data_Hydro/EHype_ensemble_daily_mean_EUR-11_ICHEC-EC-EARTH_rcp45_r12i1p1_KNMI-RACMO22E_2006-2100_reservoir.csv", float_format = "%5.1f", sep=";")
Q_all_scen.loc[:, plants_out].xs("rcp85", axis=1, level=1).to_csv("/path/to/Data_deliver/Data_Hydro/EHype_ensemble_daily_mean_EUR-11_ICHEC-EC-EARTH_rcp85_r12i1p1_KNMI-RACMO22E_2006-2100_reservoir.csv", float_format = "%5.1f", sep=";")

# Power per plant; rows containing missing values are dropped per scenario.
# `axis` is passed by keyword because dropna() no longer accepts it positionally.
daily_reservoir_all.loc[:, plants_out].xs("historical", axis=1, level=1).dropna(axis=0).to_csv("/path/to/Data_deliver/Data_Hydro/Power_daily_mean_historical_1971-2005_reservoir.csv", float_format = "%5.1f", sep=";")
daily_reservoir_all.loc[:, plants_out].xs("rcp45", axis=1, level=1).dropna(axis=0).to_csv("/path/to/Data_deliver/Data_Hydro/Power_daily_mean_EUR-11_ICHEC-EC-EARTH_rcp45_r12i1p1_KNMI-RACMO22E_2006-2100_reservoir.csv", float_format = "%5.1f", sep=";")
daily_reservoir_all.loc[:, plants_out].xs("rcp85", axis=1, level=1).dropna(axis=0).to_csv("/path/to/Data_deliver/Data_Hydro/Power_daily_mean_EUR-11_ICHEC-EC-EARTH_rcp85_r12i1p1_KNMI-RACMO22E_2006-2100_reservoir.csv", float_format = "%5.1f", sep=";")


#%% aggregate countrywise
# NOTE(review): this cell referenced names that are not defined in this file
# (`daily_all`, `power_ann`, `power_hist_mean`). They are mapped to
# `daily_reservoir_all`, `power_cny_agg` and `power_cny_hist_mean`; confirm
# that this matches the intended calculation.

plant_ids = daily_reservoir_all.columns.get_level_values(0)
scenarios = daily_reservoir_all.columns.get_level_values(1)
# The id-indexed plant table is used for the lookup; `pp_gdf` itself still
# carries the default index from read_excel.
countries = pp_gdf_filt.loc[plant_ids, "country_code"]

# Add the country code as a third column level: (plant id, scenario, country).
daily_reservoir_all.columns = pd.MultiIndex.from_arrays([plant_ids,
                                                         scenarios,
                                                         countries.values])

# Sum over plants per (scenario, country). Sums of exactly 0 become NaN,
# which also covers periods in which a scenario has no data at all.
power_cny_agg = daily_reservoir_all.T.groupby(level=[1, 2]).sum().T.replace(0, np.nan)
# Mean historical value per country, used as the reference level.
power_cny_hist_mean = power_cny_agg["historical"].mean()
power_cny_rel = power_cny_agg/power_cny_hist_mean[power_cny_agg.columns.get_level_values(1)].values

# Columns are now (scenario, country), so the scenario is selected on level 0.
# NOTE(review): the exported values are relative to the historical mean;
# "%5.1f" keeps only one decimal - confirm that this precision is intended.
power_cny_rel.xs("historical", axis=1, level=0).dropna(axis=0).to_csv("/path/to/Data_deliver/Data_Hydro/Power_daily_mean_historical_1971-2005_reservoir_countries.csv", float_format = "%5.1f", sep=";")
power_cny_rel.xs("rcp45", axis=1, level=0).dropna(axis=0).to_csv("/path/to/Data_deliver/Data_Hydro/Power_daily_mean_EUR-11_ICHEC-EC-EARTH_rcp45_r12i1p1_KNMI-RACMO22E_2006-2100_reservoir_countries.csv", float_format = "%5.1f", sep=";")
power_cny_rel.xs("rcp85", axis=1, level=0).dropna(axis=0).to_csv("/path/to/Data_deliver/Data_Hydro/Power_daily_mean_EUR-11_ICHEC-EC-EARTH_rcp85_r12i1p1_KNMI-RACMO22E_2006-2100_reservoir_countries.csv", float_format = "%5.1f", sep=";")

