#!/usr/bin/env python

"""Data Analysis & Plotting - B (Daily Prevalence Estimation Plots)"""

import os
import sys
import random
import pickle
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.transforms
from matplotlib import gridspec
import tables as h5
import numpy as np
import geopandas as gpd
import imageio
import re
import shutil


class CreateAnimationPlots():
    """Process data output from the simulation. To include everything ranging from
    the importing of shapefiles, to saving of map figures."""
    def __init__(self, site, model_dir, figsize_a4, dpi, prev_res, jitter,
                 num_days):
        """Configure paths, load the study polygon and precompute shared plot state.

        Parameters
        ----------
        site : str
            Study-site name. Used in every data path and as the GeoPackage
            layer name of the study polygon.
        model_dir : str
            Model sub-directory under the per-site models/simulation folders.
            File names are concatenated directly onto it, so it should end
            with "/".
        figsize_a4 : str or sequence of numbers
            Figure size, either as a bracketed comma-separated string such as
            "[8.27,11.69]" or as a sequence like (8.27, 11.69).
        dpi, prev_res, jitter, num_days : int-like
            Converted with ``int()``. ``jitter`` is the maximum absolute
            offset added to each individual's household coordinates.
            ``num_days`` is the number of per-day limit files to collate.
        """
        self.site = site
        self.figsize_a4 = self._parse_figsize(figsize_a4)
        self.dpi = int(dpi)
        self.prev_res = int(prev_res)
        self.jitter = int(jitter)
        self.num_days = int(num_days)

        self.model_dir = f"../data/models/{site}/" + model_dir
        self.sim_dir = f"../data/simulation/{site}/" + model_dir
        self.save_dir = f"../data/analysis/{site}/{model_dir}/"
        self.temp_dir = self.save_dir + 'temp_prev/'
        self.input_file_path = self.model_dir + "model_output.h5"
        self.results_file_path = self.sim_dir + "sim_output.h5"

        if not os.path.exists(self.save_dir):
            # NOTE(review): generate_output_directories is not defined alongside
            # the other methods of this class -- confirm it exists, otherwise
            # this branch raises AttributeError.
            self.generate_output_directories(self.save_dir, self.temp_dir)
        # exist_ok avoids the check-then-create race of an isdir() test
        os.makedirs(self.temp_dir, exist_ok=True)

        # study-area polygon, reprojected to EPSG:3857
        self.site_poly = gpd.read_file("../data/init_raw_data/study_polygons/study_polygons.gpkg",
                                       layer=site,
                                       driver="GPKG").to_crs(3857)

        # map extent shared by every frame (see establish_end_figure_dims for
        # the ordering of the four values)
        self.left, self.right, self.top, self.bottom = self.establish_end_figure_dims()

        # one colour scale (min, max) shared by every frame
        self.odds_range = self.collate_prev_limits()

        # fixed, jittered plotting positions of every individual
        self.ind_x, self.ind_y = self.get_ind_coords()

    @staticmethod
    def _parse_figsize(figsize):
        """Return ``figsize`` as a list of floats.

        Accepts a bracketed, comma-separated string such as "[8.27,11.69]"
        (surrounding whitespace and [] / () brackets are stripped) or any
        iterable of numbers such as (8.27, 11.69).
        """
        if isinstance(figsize, str):
            figsize = figsize.strip().strip('[]()').split(',')
        return [float(v) for v in figsize]

    def get_ind_coords(self):
        """Return jittered x/y plotting coordinates for every individual.

        Reads ``/demography/people`` and ``/networks/houses`` from the model
        output file. Each individual is placed at its household's
        ``lon``/``lat`` plus an independent uniform offset in
        ``[-self.jitter, self.jitter]`` on each axis, so individuals sharing a
        household are not drawn on top of one another.

        Returns
        -------
        tuple of numpy.ndarray
            ``(ind_x, ind_y)``, one entry per individual, in file order.
        """
        # the context manager closes the HDF5 file even if a read fails;
        # [:] materialises the tables in memory so they outlive the file
        with h5.open_file(self.input_file_path, mode='r') as input_file:
            individuals = input_file.root.demography.people[:]
            households = input_file.root.networks.houses[:]

        return self._jitter_positions(households, individuals, self.jitter)

    @staticmethod
    def _jitter_positions(households, individuals, jitter):
        """Offset each individual's household position by uniform noise.

        ``individuals['household']`` indexes rows of ``households``, which
        must have ``lon`` and ``lat`` fields. All x offsets are drawn from
        ``random`` before any y offsets, so results are reproducible under
        ``random.seed``.
        """
        ind_x = np.array([households[person['household']]['lon'] + random.uniform(-jitter, jitter)
                          for person in individuals])
        ind_y = np.array([households[person['household']]['lat'] + random.uniform(-jitter, jitter)
                          for person in individuals])
        return ind_x, ind_y

    def create_animation_plot(self, day_no, run_no):
        """Render and save one animation frame for ``day_no`` of run ``run_no``.

        The left panel shows the kriged surface loaded from
        ``<temp_dir>/prev_<day_no>.p`` over the study polygon. The right panel
        shows every individual, with those whose state equals 2 in the
        simulation output drawn in red. The frame is written to
        ``<temp_dir>/<day_no>.png``.

        Parameters
        ----------
        day_no : int
            Row index into the selected run's ``inf_data`` array.
        run_no : int-like
            0-based position of the run among the children of ``/inf_data``,
            in iteration order.

        Raises
        ------
        IndexError
            If ``run_no`` does not identify an existing run.
        """
        # Read only this day's state vector, then release the HDF5 file before
        # plotting; ``with`` also closes it if the read fails.
        with h5.open_file(self.results_file_path, mode='r') as output_file:
            run = self._select_run(output_file.root.inf_data, run_no)
            model_output = run[day_no, :]

        fig, axes = self.create_animation_axis(self.ind_x, self.ind_y)
        try:
            # add infections to the plot (state code 2 is infected)
            infected = np.where(model_output == 2)[0]
            axes[1].scatter(self.ind_x[infected], self.ind_y[infected], s=40, c='r',
                            edgecolor='k', linewidth=0.5, zorder=10)

            # update the day counter and overlay the kriged surface
            self.label2.set_text(f"Day: {day_no}")
            axes[0] = self.site_poly.plot(ax=axes[0], color=(0.0, 0.0, 0.0, 0.0),
                                          edgecolor=(0.0, 0.0, 0.0, 1.0))
            axes[0].tick_params(axis='both', which='major', labelsize=6)
            # SECURITY: pickle.load can execute arbitrary code. These files are
            # expected to be produced by this pipeline; never point temp_dir at
            # untrusted data.
            with open(os.path.join(self.temp_dir, f"prev_{day_no}.p"), "rb") as f:
                krige_data = pickle.load(f)

            vmin, vmax = self.odds_range
            # self.top / self.bottom come from ``top, bottom = get_ylim()``, i.e.
            # (lower, upper) for a non-inverted axis, so the extent's bottom/top
            # receive (upper, lower). Presumably this matches the row order of
            # krige_data -- TODO confirm.
            axes[0].imshow(krige_data, origin='lower',
                           extent=[self.left, self.right, self.bottom, self.top],
                           vmin=vmin,
                           vmax=vmax,
                           cmap=matplotlib.cm.coolwarm,
                           alpha=0.75)
            axes[0].set_xlim(self.left, self.right)
            axes[0].set_ylim(self.top, self.bottom)
            axes[0].tick_params(axis='x', bottom=True, top=False, labelbottom=True, labeltop=False)
            axes[1].set_xlim(self.left, self.right)
            axes[1].set_ylim(self.top, self.bottom)

            cbar = fig.colorbar(
                matplotlib.cm.ScalarMappable(norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax),
                                             cmap=matplotlib.cm.coolwarm),
                cax=axes[2], orientation='horizontal', alpha=0.75)
            cbar.set_label('Ordinary Krige (spherical model) of log odds (col:uncol)')
            cbar.set_ticks([])

            fig.savefig(os.path.join(self.temp_dir, f"{day_no}.png"))
        finally:
            # always release the figure so a failed frame does not leak memory
            plt.close(fig)

    @staticmethod
    def _select_run(runs, run_no):
        """Return the ``run_no``-th (0-based) item produced by iterating ``runs``.

        ``run_no`` is converted with ``int()``. An index that is negative or
        past the end raises IndexError, instead of some other run being
        returned silently.
        """
        run_no = int(run_no)
        if run_no >= 0:
            for idx, run in enumerate(runs):
                if idx == run_no:
                    return run
        raise IndexError(f"run_no {run_no} does not match any simulation run")

    def animate(self):
        """Encode the numbered PNG frames in ``temp_dir`` into an MP4 clip.

        First, every file in ``temp_dir`` whose name (without extension)
        contains a letter is deleted. That is anything other than the
        ``<day_no>.png`` frames written by ``create_animation_plot``, e.g. the
        ``prev_*.p`` pickles and the ``max_prev*`` / ``min_prev*`` limit files.
        The remaining frames are ordered by their integer day number and
        encoded with FFMPEG/libx264 to ``<save_dir>/animation/clip.mp4``.
        Finally ``temp_dir`` is removed.
        """
        temp_dir = self.temp_dir

        # drop non-frame files so only "<int>.<ext>" names remain to be sorted
        for name in os.listdir(temp_dir):
            if re.search('[a-zA-Z]', os.path.splitext(name)[0]):
                os.remove(os.path.join(temp_dir, name))

        frame_names = sorted(os.listdir(temp_dir),
                             key=lambda name: int(os.path.splitext(name)[0]))

        # the output folder may not exist yet
        out_dir = os.path.join(self.save_dir, "animation")
        os.makedirs(out_dir, exist_ok=True)

        # Input frame rate in frames per second (0.5 = half speed, 2 = double
        # speed). "-r 4" in output_params sets ffmpeg's output rate, but the
        # playback speed appears to follow this input fps -- TODO confirm.
        fps = 1
        with imageio.get_writer(os.path.join(out_dir, "clip.mp4"),
                                macro_block_size=1,
                                ffmpeg_log_level='error',
                                format="FFMPEG",
                                fps=fps,
                                pixelformat="yuv420p",
                                output_params=["-c:v", "libx264", "-r", "4"]) as writer:
            for frame in frame_names:
                writer.append_data(imageio.imread(os.path.join(temp_dir, frame)))

        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)


    def establish_end_figure_dims(self):
        """Return the map extent the study polygon receives on a template figure.

        A throw-away figure with the same 2x2 GridSpec layout used for the
        animation frames is built and the polygon is drawn on both map panels.
        The automatic limits of the right-hand panel are then read back and
        the figure is closed.

        Returns
        -------
        tuple of float
            ``(x_low, x_high, y_low, y_high)`` as reported by ``get_xlim`` /
            ``get_ylim``. ``__init__`` unpacks this as
            ``left, right, top, bottom``, so the value named "top" there is
            the lower y limit.
        """
        template = plt.figure(figsize=self.figsize_a4, dpi=self.dpi)
        grid = gridspec.GridSpec(2, 2, template, wspace=0.15, hspace=0.02,
                                 width_ratios=[1, 1], height_ratios=[1, 0.01])
        left_panel = plt.subplot(grid[0, 0])
        right_panel = plt.subplot(grid[0, -1])
        # colour-bar slot: unused here, kept so the template mirrors the frame layout
        plt.subplot(grid[-1, 0])
        grid.tight_layout(template)
        template.suptitle('Template Title', fontsize=14)

        drawn = []
        for panel in (left_panel, right_panel):
            panel = self.site_poly.plot(ax=panel)
            panel.tick_params(axis='both', which='major', labelsize=6)
            drawn.append(panel)

        x_low, x_high = drawn[1].get_xlim()
        y_low, y_high = drawn[1].get_ylim()
        plt.close(template)

        return (x_low, x_high, y_low, y_high)

    def collate_prev_limits(self):
        """Return the global ``(min, max)`` of the per-day colour-scale limits.

        For each day ``d`` in ``range(self.num_days)``, the first line of
        ``<temp_dir>/max_prev<d>.txt`` and of ``<temp_dir>/min_prev<d>.txt`` is
        parsed as a float. The smallest minimum and the largest maximum over
        all days are returned, so every frame can share one colour scale.

        Raises
        ------
        ValueError
            If ``num_days`` is not positive, or a first line is not a number.
        FileNotFoundError
            If a per-day limits file is missing.
        """
        if self.num_days <= 0:
            raise ValueError(
                f"num_days must be positive to collate prevalence limits, got {self.num_days}")

        def first_float(name):
            # only the first line of each limits file is meaningful
            with open(os.path.join(self.temp_dir, name)) as fh:
                return float(fh.readline())

        maxs = [first_float(f"max_prev{d}.txt") for d in range(self.num_days)]
        mins = [first_float(f"min_prev{d}.txt") for d in range(self.num_days)]

        return (min(mins), max(maxs))

    def create_animation_axis(self, xs, ys):
        """Build an empty animation frame; return ``(fig, [map_ax, points_ax, cbar_ax])``.

        Left panel: the study polygon with a transparent fill and black
        outline, ready for the kriged surface. Right panel: the polygon filled
        white, with the points ``(xs, ys)`` drawn as small white markers.
        Bottom-left: a thin axis reserved for the colour bar.

        Side effect: sets ``self.label1`` (panel title text) and
        ``self.label2`` (day counter text, initialised to "Day: 1") to new
        text artists on this figure.
        """
        fig = plt.figure(figsize=self.figsize_a4, dpi=self.dpi)
        layout = gridspec.GridSpec(2, 2, fig,
                                   wspace=0.15, hspace=0.02,
                                   width_ratios=[1, 1], height_ratios=[1, 0.01])

        map_ax = plt.subplot(layout[0, 0])
        points_ax = plt.subplot(layout[0, -1])
        cbar_ax = plt.subplot(layout[-1, 0])
        layout.tight_layout(fig)

        map_ax = self.site_poly.plot(ax=map_ax, color=(0.0, 0.0, 0.0, 0.0),
                                     edgecolor=(0.0, 0.0, 0.0, 1.0))
        map_ax.tick_params(axis='both', which='major', labelsize=6)
        points_ax = self.site_poly.plot(ax=points_ax, color='w', edgecolor='k')
        points_ax.tick_params(axis='both', which='major', labelsize=6)

        self.label1 = map_ax.text(0.2, 0.95, 'Contamination Prevalence',
                                  fontsize=12, transform=map_ax.transAxes)
        self.label2 = points_ax.text(0.4, 0.95, 'Day: 1',
                                     fontsize=18, transform=points_ax.transAxes)

        # base layer of all points, drawn above the polygon
        points_ax.scatter(xs, ys, s=10, c='w', edgecolor='k', linewidth=0.5, zorder=10)

        return fig, [map_ax, points_ax, cbar_ax]
