#!/usr/bin/env python

"""Data Analysis & Plotting - A (Daily Prevalence Data Estimation)"""

import os
import pickle
from math import log
import shutil
import tables as h5
import numpy as np
import matplotlib.pyplot as plt 
import geopandas as gpd
from matplotlib import gridspec
from pykrige.ok import OrdinaryKriging


class CreatePrevalenceEstimations():
    """Build kriged household log-odds prevalence grids from simulation output.

    Covers loading the study-site polygon, recording the map extent of the
    template figure, and writing one interpolated log-odds grid (plus its
    minimum and maximum) per requested day into a temporary directory.
    """

    # Sub-directories created beneath the per-model analysis directory.
    _OUTPUT_SUBDIRS = ('resource', 'demography', 'epidemiology', 'animation',
                       'networks')

    def __init__(self, site, model_dir, figsize_a4, dpi, prev_res, jitter,
                 num_days):
        """Resolve input/output paths, create output folders, load the site.

        Args:
            site: Study-site name; used in the data paths and as the layer
                name read from the study-polygon GeoPackage.
            model_dir: Model sub-directory name appended to the per-site
                model, simulation and analysis paths.
            figsize_a4: Figure size, either as a string with one delimiter
                character at each end (e.g. ``"(8.27,11.69)"``) or as a
                sequence of numbers.
            dpi: Figure resolution; converted with ``int``.
            prev_res: Grid cell size used to divide the plot extent;
                converted with ``int``.
            jitter: Accepted for interface compatibility; not used here.
            num_days: Accepted for interface compatibility; not used here.
        """
        self.site = site
        if isinstance(figsize_a4, str):
            # Strip the surrounding brackets before splitting on commas.
            figsize_a4 = figsize_a4[1:-1].split(',')
        self.figsize_a4 = [float(i) for i in figsize_a4]
        self.dpi = int(dpi)
        self.prev_res = int(prev_res)

        self.model_dir = f"../data/models/{site}/" + model_dir
        self.sim_dir = f"../data/simulation/{site}/" + model_dir
        self.save_dir = f"../data/analysis/{site}/{model_dir}/"
        self.temp_dir = self.save_dir + 'temp_prev/'
        # os.path.join keeps these valid whether or not model_dir was given
        # with a trailing slash.
        self.input_file_path = os.path.join(self.model_dir, "model_output.h5")
        self.results_file_path = os.path.join(self.sim_dir, "sim_output.h5")

        if not os.path.exists(self.save_dir):
            self.generate_output_directories(self.save_dir, self.temp_dir)
        # exist_ok avoids a check-then-create race between concurrent runs.
        os.makedirs(self.temp_dir, exist_ok=True)

        # load site poly
        self.site_poly = gpd.read_file("../data/init_raw_data/study_polygons/study_polygons.gpkg",
                                       layer=site,
                                       driver="GPKG").to_crs(3857)
        self.figure_dims = self.establish_end_figure_dims()

    def generate_output_directories(self, output_directory, temp_dir):
        """Create the analysis sub-directories beneath ``output_directory``.

        Safe to call repeatedly: directories that already exist are left
        untouched.

        Args:
            output_directory: Directory under which the sub-directories are
                created (created itself if missing).
            temp_dir: Accepted for interface compatibility; not used here.
        """
        for subdir in self._OUTPUT_SUBDIRS:
            os.makedirs(os.path.join(output_directory, subdir), exist_ok=True)

    def establish_end_figure_dims(self):
        """Plot the site polygon on a template figure and record its extent.

        Returns:
            A 4-tuple ``(x_first, x_second, y_first, y_second)``: the values
            of ``get_xlim()`` followed by those of ``get_ylim()`` for the
            right-hand map axis. Matplotlib returns y-limits as
            ``(bottom, top)``, so for a non-inverted axis the third element
            is the lower y-limit and the fourth the upper one.
        """
        fig = plt.figure(figsize=self.figsize_a4, dpi=self.dpi)
        gs = gridspec.GridSpec(2, 2, fig, wspace=0.15, hspace=0.02,
                               width_ratios=[1, 1], height_ratios=[1, 0.01])
        left_ax = plt.subplot(gs[0, 0])    # left map
        right_ax = plt.subplot(gs[0, -1])  # right map
        plt.subplot(gs[-1, 0])             # colourbar slot of the template
        gs.tight_layout(fig)
        fig.suptitle('Template Title', fontsize=14)

        left_ax = self.site_poly.plot(ax=left_ax)
        left_ax.tick_params(axis='both', which='major', labelsize=6)
        right_ax = self.site_poly.plot(ax=right_ax)
        right_ax.tick_params(axis='both', which='major', labelsize=6)

        x_first, x_second = right_ax.get_xlim()
        y_first, y_second = right_ax.get_ylim()
        plt.close(fig)

        return (x_first, x_second, y_first, y_second)

    def create_log_odds_matrix(self, day_no, run_no):
        """Krige household empirical log-odds of infection for one day.

        For every household the empirical log-odds of being infectious
        (state code 2) is computed, the values are interpolated with ordinary
        kriging (spherical variogram) over a regular grid spanning
        ``self.figure_dims``, and the results are written to ``self.temp_dir``:
        ``max_prev{day_no}.txt``, ``min_prev{day_no}.txt`` and the pickled
        grid ``prev_{day_no}.p``.

        Args:
            day_no: Row index of the day within a run's output array.
            run_no: Zero-based position of the run within ``inf_data``. If
                no run has this position, the last run is used and a
                warning is issued.

        Raises:
            ValueError: If ``inf_data`` contains no runs.
        """
        import warnings

        # Grid dimensions. figure_dims holds get_xlim() then get_ylim(),
        # i.e. (x_first, x_second, y_lower, y_upper) for a non-inverted axis.
        left, right, y_lower, y_upper = self.figure_dims
        n_cols = int(abs(left - right) / self.prev_res)
        n_rows = int(abs(y_lower - y_upper) / self.prev_res)
        gridx = np.linspace(left, right, n_cols)        # x-axis
        # y-axis: runs from the second y-limit to the first (same order as
        # before this refactor).
        gridy = np.linspace(y_upper, y_lower, n_rows)

        xs = []
        ys = []
        zs = []
        # Context managers guarantee both HDF5 files are closed, even when
        # reading or indexing fails.
        with h5.open_file(self.input_file_path, mode='r') as input_file, \
                h5.open_file(self.results_file_path, mode='r') as output_file:
            # Walk the runs until the requested position is reached.
            selected_run = None
            for position, run in enumerate(output_file.root.inf_data):
                selected_run = run
                if position == run_no:
                    break
            else:
                if selected_run is None:
                    raise ValueError(
                        f"no runs found in inf_data of {self.results_file_path}")
                warnings.warn(
                    f"run_no={run_no} not found in inf_data; "
                    "falling back to the last available run")
            model_output = selected_run[day_no, :]

            # NB S: 0, L: 1, I: 2, R: 3
            inf_indices = np.where(model_output == 2)[0]

            houses = input_file.root.networks.houses
            for house_id in range(len(houses)):
                # Read each household record once rather than once per
                # infected individual.
                record = houses[house_id]
                occupants = record['occupants']
                col_house = sum(1 for person in inf_indices
                                if person in occupants)
                n_house = len(occupants)
                xs.append(record['lon'])
                ys.append(record['lat'])

                # "From Generalized Linear Models With Examples in R",
                # P. Dunn pg 341: the 0.5 offsets keep the odds finite when
                # none or all occupants are infected.
                empirical_odds = (col_house + 0.5) / (n_house - col_house + 0.5)
                zs.append(log(empirical_odds))

        # Ordinary kriging of the household log-odds. The variogram
        # parameters are not supplied, so they are fitted by the library.
        OK = OrdinaryKriging(xs, ys, zs, variogram_model='spherical',
                             verbose=False, enable_plotting=False)

        # Kriged grid (z) and variance grid (unused here).
        z, _variance = OK.execute('grid', gridx, gridy)

        with open(os.path.join(self.temp_dir, f"max_prev{day_no}.txt"), "w") as max_out:
            max_out.write(str(np.max(z)) + '\n')
        with open(os.path.join(self.temp_dir, f"min_prev{day_no}.txt"), "w") as min_out:
            min_out.write(str(np.min(z)) + '\n')
        with open(os.path.join(self.temp_dir, f"prev_{day_no}.p"), 'wb') as prev_out:
            pickle.dump(z, prev_out)
