import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.stats as stat

import lib.funcs.perc_contributions_WRAP
import lib.funcs.dat_io as io
import lib.funcs.foodsupply_trajectory
import lib.funcs.name_alias
import lib.dat.colours
import lib.dat.food_commodity_seperation
import lib.dat.fodder_crops


def main(area_index):
    """Project crop and fodder land-area requirements (ha) for every area in ``area_index``.

    For each area the function:

    1. divides the vegetal production demand for human consumption (converted
       from kg/year to tonnes/year, with the "Alcohol" row excluded) by the
       projected crop yield (tonnes/ha/year) to obtain crop land area;
    2. converts the fodder production quota (MJ/year) to tonnes/year using each
       fodder crop's energy density and divides it by a projected fodder yield
       (see ``_fodder_yield_projection``) to obtain fodder land area;
    3. rescales both projections so that their 2017 values match the
       historical "Area harvested" for 2017 (see ``_calibrate_to_2017``);
    4. saves both tables with ``io.save`` into the ``land_use`` sub-directory
       of the region's data path, as ``fodder_area_<area>`` and
       ``crop_area_<area>``.

    Parameters
    ----------
    area_index : pandas.DataFrame
        Table whose index holds area names and which has ``Continent`` and
        ``Region`` columns; each region's input files are read from
        ``data/<Continent>/<Region>`` (written with Windows separators).

    Returns
    -------
    None
        Results are only written to disk.
    """
    fodder_crops_list = lib.dat.fodder_crops.fodder_crops
    fodder_crops_properties = pd.read_excel("lib\\dat\\fodder_product_properties.xlsx", index_col=0)
    # Loop-invariant item -> item-group table: read once rather than once per area.
    production_groups = pd.read_csv("data\\FAOSTAT_data_1-21-2020_production_dat_groups.csv")

    for continent in area_index.Continent.unique():
        for region in area_index[area_index.Continent == continent].Region.unique():
            path = f"data\\{continent}\\{region}"
            production_dat = io.load(path, f"Production_Crops_E_{region}")
            area_harvested = production_dat.xs("Area harvested", level="Element")

            for area in area_index[area_index.Region == region].index.to_list():
                # Projected crop yields (tonnes/ha/year). A zero yield is treated as
                # missing so the division below gives NaN instead of an infinite area.
                crop_yield_projection = io.load(path, f"yield_production\\yield_projection_{area}")
                crop_yield_projection = crop_yield_projection.replace(0, np.nan)

                # Vegetal production demand for humans: kg/year -> tonnes/year, alcohol excluded.
                crop_production_demand = io.load(path, f"production\\production_mass_vegetal_for_human_{area}") / 1000
                crop_production_demand = crop_production_demand[~crop_production_demand.index.isin(["Alcohol"])]

                # ``.values`` makes this division positional: assumes demand and yield
                # tables share row and column order - TODO confirm.
                crop_land_area_projection = crop_production_demand / crop_yield_projection.values  # ha

                # Fodder: energy quota (MJ/year) -> mass (tonnes/year) -> land (ha).
                fodder_production_quota = io.load(path, f"livestock\\fodder_production_quota_{area}")
                # The 1E-04 factor presumably converts FAOSTAT hg/ha to tonnes/ha - TODO confirm.
                yield_hist = production_dat.xs(area, level="Area").xs("Yield", level="Element") * 1E-04
                fodder_yield = _fodder_yield_projection(
                    fodder_crops_list,
                    crop_yield_projection,
                    yield_hist,
                    fodder_production_quota,
                    fodder_crops_properties,
                    area,
                )
                fodder_production_quota_mass = _fodder_quota_mass(
                    fodder_production_quota, fodder_crops_list, fodder_crops_properties
                )
                fodder_land_area_projection = fodder_production_quota_mass / fodder_yield  # ha

                # Historical harvested area, aggregated to the food-commodity grouping.
                area_area_dat = area_harvested.xs(area, level="Area")
                area_agg = _historical_area_by_group(area_area_dat, production_groups, area)

                crop_land_area_projection, fodder_land_area_projection = _calibrate_to_2017(
                    crop_land_area_projection, fodder_land_area_projection, area_agg, area_area_dat
                )

                io.save(f"{path}\\land_use", f"fodder_area_{area}", fodder_land_area_projection)
                io.save(f"{path}\\land_use", f"crop_area_{area}", crop_land_area_projection)


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 1 (i.e. no rescaling) when the denominator is 0.

    NumPy scalars do not raise ``ZeroDivisionError`` on division by zero; they
    return inf/nan with a RuntimeWarning. The zero case is therefore tested
    explicitly instead of being caught.
    """
    if denominator == 0:
        return 1
    return numerator / denominator


def _fodder_property(fodder_crops_properties, crop, column):
    """Return ``column`` for ``crop`` from the fodder-properties table.

    The crop's own name is tried first; on ``KeyError`` its alias from
    ``lib.funcs.name_alias.conv`` is tried. A ``KeyError`` propagates if
    neither lookup succeeds.
    """
    try:
        return fodder_crops_properties.loc[crop][column]
    except KeyError:
        return fodder_crops_properties.loc[lib.funcs.name_alias.conv(crop)][column]


def _historical_trend_yield(yield_rows, projection_years):
    """Extrapolate one item's historical yield over ``projection_years``.

    ``yield_rows`` is the historical yield table for a single item; only its
    first row is used. Its columns are assumed to be consecutive years ending
    in 2017 - TODO confirm.

    If a linear fit over the last 30 historical values shows a significant
    positive trend (r**2 > 0.2, p < 0.05, slope > 0), the function returns a
    list. The list runs linearly from the mean of the last 5 historical values
    (at 2013) to the fitted 2050 value, floored at the historical minimum.
    Otherwise it returns a single float: the mean of the historical values in
    the ``iloc[-10:-1]`` window, used as a flat projection.
    """
    fit_window = 30
    fit = stat.linregress(np.arange(2018 - fit_window, 2018, 1), yield_rows.values[0][-fit_window:])
    slope, intercept, r_value, p_value = fit[0], fit[1], fit[2], fit[3]

    if r_value ** 2 > 0.2 and p_value < 0.05 and slope > 0:
        value_2050 = slope * 2050 + intercept
        value_start = yield_rows.iloc[:, -5:].mean(axis=1).values[0]
        value_floor = yield_rows.min(axis=1).values[0]
        step = (value_2050 - value_start) / (2050 - 2013)
        trend = [value_start + (year - 2013) * step for year in projection_years]
        return [value_floor if value < value_floor else value for value in trend]

    # Select the item row positionally, as above: the "Item" level has already
    # been dropped by ``xs``, so a label lookup by crop name cannot match.
    # NOTE(review): ``iloc[-10:-1]`` excludes the latest year - confirm intended.
    return yield_rows.iloc[0].astype(float).iloc[-10:-1].mean()


def _fodder_yield_projection(fodder_crops_list, crop_yield_projection, yield_hist,
                             fodder_production_quota, fodder_crops_properties, area):
    """Build the 2013-2050 yield projection (tonnes/ha) for each fodder crop.

    Each crop's yield comes from the first source that applies, in this order:

    1. the crop's own row in ``crop_yield_projection``;
    2. the "Sugar Crops" projection, if the crop's alias is "Sugar beet";
    3. a trend extrapolation of the alias's row in ``yield_hist``
       (see ``_historical_trend_yield``).

    A crop left with any missing value, and which has a positive total
    production quota, is filled with its ``fallback_yield`` property.

    Returns
    -------
    pandas.DataFrame
        Rows are the fodder crops and columns are the years 2013-2050.
    """
    projection_years = np.arange(2013, 2051, 1)
    fodder_yield = pd.DataFrame(index=projection_years, columns=fodder_crops_list).T
    projected_crops = set(crop_yield_projection.index.to_list())

    for crop in fodder_crops_list:
        if crop in projected_crops:
            fodder_yield.loc[crop] = crop_yield_projection.loc[crop]
        else:
            crop_conv = lib.funcs.name_alias.conv(crop)
            try:
                if crop_conv == "Sugar beet":
                    fodder_yield.loc[crop] = crop_yield_projection.loc["Sugar Crops"]
                else:
                    fodder_yield.loc[crop] = _historical_trend_yield(
                        yield_hist.xs(crop_conv, level="Item"), projection_years
                    )
            except KeyError:
                print(f"KeyError in {__name__}; no yield data for {crop} for {area}, using fallback yield value")

        # Only crops that are actually demanded need a usable yield.
        if fodder_yield.loc[crop].isnull().values.any() and fodder_production_quota.loc[crop].sum(axis=0) > 0:
            fodder_yield.loc[crop] = _fodder_property(fodder_crops_properties, crop, "fallback_yield")

    return fodder_yield


def _fodder_quota_mass(fodder_production_quota, fodder_crops_list, fodder_crops_properties):
    """Convert the fodder production quota from MJ/year to tonnes/year.

    Missing quota values are treated as 0. Each fodder crop's row is divided by
    its ``energy_density`` (MJ/kg, scaled to MJ/tonne). Rows of the quota table
    that are not fodder crops stay NaN.
    """
    quota_energy = fodder_production_quota.fillna(0)  # MJ / year
    quota_mass = pd.DataFrame(columns=quota_energy.columns, index=quota_energy.index)
    for crop in fodder_crops_list:
        energy_density = _fodder_property(fodder_crops_properties, crop, "energy_density") * 1000  # MJ / tonne
        quota_mass.loc[crop] = quota_energy.loc[crop] / energy_density  # tonnes / year
    return quota_mass


def _historical_area_by_group(area_area_dat, production_groups, area):
    """Aggregate one area's historical harvested area into the food-commodity grouping.

    Rows are the staple crops taken individually, followed by the grouped
    vegetal products, "Luxuries (excluding Alcohol)" and "Other". Columns are
    those of ``area_area_dat``. Group membership comes from the "Item Group"
    column of ``production_groups`` or from hard-coded item lists. Rows with no
    data remain NaN.
    """
    fcs = lib.dat.food_commodity_seperation
    staple_crops = fcs.staple_crops
    index_list = staple_crops + fcs.vegetal_prods_grouped_list + ["Luxuries (excluding Alcohol)"] + ["Other"]
    area_agg = pd.DataFrame(index=index_list, columns=area_area_dat.columns.to_list())

    # Naming conventions differ between the production dataset and the FBS.
    for crop in staple_crops:
        if crop == "Rape and Mustardseed":
            # Mean of whichever component items are present; a missing component
            # no longer wipes out the one that exists.
            # NOTE(review): harvested areas of the two components might more naturally be summed - confirm.
            components = []
            for item in ("Rapeseed", "Mustard seed"):
                try:
                    components.append(area_area_dat.xs(item, level="Item").values[0])
                except KeyError:
                    pass
            if components:
                area_agg.loc[crop] = np.mean(components, axis=0)
            continue

        try:
            area_agg.loc[crop] = area_area_dat.xs(crop, level="Item").values
        except KeyError:
            crop_alias = lib.funcs.name_alias.conv(crop)
            try:
                area_agg.loc[crop] = area_area_dat.xs(crop_alias, level="Item").values
            except KeyError:
                print(f"Warning: KeyError in {__name__} - no area data for {crop_alias} in {area}.")

    def group_items(group):
        # Items listed under ``group`` in the FAOSTAT item-group table.
        return production_groups.loc[production_groups["Item Group"] == group]["Item"].to_list()

    # Wheat, rice and maize are staple crops, so they are excluded from the cereal group.
    cereal_items = group_items("Cereals, Total")
    for staple in ("Wheat", "Rice, paddy", "Maize"):
        cereal_items.remove(staple)

    # Potatoes and cassava are staple crops, so they are excluded from the starchy-roots group.
    root_items = group_items("Roots and Tubers, Total")
    for staple in ("Potatoes", "Cassava"):
        root_items.remove(staple)

    # "Effective yield" oilcrops: oil-equivalent items, minus those handled as staples.
    oil_droplist = ["Soyabeans", "Oil, Palm", "Sunflower seed", "Rapeseed", "Mustard seed", "Soybeans"]
    oilcrop_items = [item for item in group_items("Oilcrops, Oil Equivalent") if item not in oil_droplist]

    items_by_row = {
        "Fruits - Excluding Wine": group_items("Fruit Primary"),
        "Cereals - Excluding Beer": cereal_items,
        "Pulses": group_items("Pulses, Total"),
        "Spices": ["Pepper (piper spp.)", "Cloves", "Spices nes", "Spices, nes", "Anise, badian, fennel, coriander"],
        "Starchy Roots": root_items,
        "Sugar & Sweeteners": ["Sugar cane", "Sugar beet"],
        "Oilcrops": oilcrop_items,
        "Vegetables": group_items("Vegetables Primary"),
        "Luxuries (excluding Alcohol)": ["Cocoa, beans", "Coffee, green", "Tea", "Maté"],
        "Other": ["Flax fibre and tow", "Fibre crops nes", "Tobacco, unmanufactured", "Seed cotton"],
    }
    item_level = area_area_dat.index.get_level_values("Item")
    for row, items in items_by_row.items():
        area_agg.loc[row] = area_area_dat.iloc[item_level.isin(items)].sum(axis=0)

    return area_agg


def _calibrate_to_2017(crop_land_area_projection, fodder_land_area_projection, area_agg, area_area_dat):
    """Rescale land-area projections so that their 2017 values match historical harvested area.

    All NaNs in the inputs are treated as 0. The crop and fodder projections use
    integer year columns (``2017``); the historical tables use "Y2017".

    - A crop in both projections: both rows are scaled by
      hist / (fodder + crop), so their 2017 sum equals the historical area.
    - A crop only in the crop projection: it is scaled by hist / crop.
    - A crop only in the fodder projection with a positive 2017 value: it is
      scaled against the harvested area of its production-dataset alias.
      "Other feed" with no positive 2017 value is zeroed; any other crop is
      left unscaled.

    A zero denominator leaves the row unscaled (ratio 1).

    Returns
    -------
    tuple of pandas.DataFrame
        ``(crop_land_area_projection, fodder_land_area_projection)``.
    """
    # FBS names -> production-dataset item names for fodder-only crops.
    fodder_item_alias = {"Barley and products": "Barley",
                         "Sorghum and products": "Sorghum",
                         "Soyabeans": "Soybeans"}

    crop_area = crop_land_area_projection.fillna(0)
    fodder_area = fodder_land_area_projection.fillna(0)
    hist_area = area_agg.fillna(0)
    crop_rows = set(crop_area.index.to_list())
    fodder_rows = set(fodder_area.index.to_list())

    for crop in crop_area.index.to_list():
        pc_2017 = crop_area.loc[crop][2017]
        hc_2017 = hist_area.loc[crop]["Y2017"]
        if crop in fodder_rows:
            pf_2017 = fodder_area.loc[crop][2017]
            hist_ratio = _safe_ratio(hc_2017, pf_2017 + pc_2017)
            crop_area.loc[crop] = crop_area.loc[crop] * hist_ratio
            fodder_area.loc[crop] = fodder_area.loc[crop] * hist_ratio
        else:
            crop_area.loc[crop] = crop_area.loc[crop] * _safe_ratio(hc_2017, pc_2017)

    for crop in fodder_area.index.to_list():
        if crop in crop_rows:
            continue
        pf_2017 = fodder_area.loc[crop][2017]
        if pf_2017 > 0:
            # NOTE(review): raises KeyError for fodder-only crops missing from the alias map.
            hc_2017 = area_area_dat.xs(fodder_item_alias[crop], level="Item")["Y2017"]
            hist_ratio = (hc_2017 / pf_2017).values[0]
        elif crop == "Other feed":
            hist_ratio = 0
        else:
            hist_ratio = 1
        fodder_area.loc[crop] = fodder_area.loc[crop] * hist_ratio

    return crop_area, fodder_area
