#!/usr/bin/env python

"""Simulator"""
import os
import numpy as np
import tensorflow as tf
tf.config.threading.set_inter_op_parallelism_threads(4)
tf.config.threading.set_intra_op_parallelism_threads(4)
import tensorflow_probability as tfp
import pickle
import tables as h5
from scipy import sparse

from sim.tf_funcs import State, EventCache
from sim.tf_funcs import sparse_np2tf, sparse_np2tf_proximity, tf_iterate



class Simulator():
    """Class to run a simple AMR colonisation model using an array approach.
    Assumes daily timesteps (for rate and prob calcs).

    Population state is held as per-person float tensors (see ``State``); each call to
    ``iterate`` runs ``num_runs`` independent simulations and writes them to HDF5.
    """
    def __init__(self, site, model_dir, num_infected, latency_mean, duration_mean,
                 net_rate, forcing_rate, forcing_delta, num_runs):
        """Load a pickled model, its networks and schedule, and build the initial state.

        Args:
            site: site name; selects ``../data/models/<site>/`` and
                ``../data/simulation/<site>/``.
            model_dir: sub-directory appended to those roots. Paths are built by string
                concatenation, so it is expected to end with '/'.
            num_infected: number of people infected at the start of each run.
            latency_mean, duration_mean: parameters of the geometric latency/recovery
                draws (cast to int).
            net_rate, forcing_rate: rates passed through to ``tf_iterate`` (cast to float).
            forcing_delta: passed to ``sparse_np2tf_proximity`` (cast to float).
            num_runs: number of simulation runs performed by ``iterate``.
        """
        self.site = site
        self.model_dir = f"../data/models/{site}/" + model_dir
        self.save_dir = f"../data/simulation/{site}/" + model_dir
        os.makedirs(self.save_dir, exist_ok=True)

        # NOTE: pickle can execute arbitrary code on load; only use trusted model files.
        with open(self.model_dir + 'model.p', 'rb') as f:
            model = pickle.load(f)

        self.out_file_path = self.save_dir + 'sim_output.h5'

        self.forcing_delta = float(forcing_delta)
        self.networks, self.network_lookup = self.load_networks(model.save_dir, self.forcing_delta)

        with open(model.save_dir + 'schedule.p', 'rb') as f:
            self.schedule = pickle.load(f)

        # 1D population state: 1.0 for infected people, 0.0 otherwise.
        self.num_people = model.num_people
        self.pop_state = tf.convert_to_tensor(np.zeros((self.num_people,)), dtype=tf.float32)
        self.num_infected = int(num_infected)

        self.params = {'latency_mean': int(latency_mean), 'duration_mean': int(duration_mean),
                       'net_rate': float(net_rate), 'forcing_rate': float(forcing_rate)}

        # generate simulation timestep == 1.
        self.state, self.event_cache, self.net_select, self.start_of_day = self.setup_sim()

        self.num_runs = int(num_runs)

    def load_networks(self, model_dir, forcing_delta):
        """Read the contact networks from ``model_dir`` and convert them to tensors.

        Args:
            model_dir: directory prefix (string-concatenated with the file names).
            forcing_delta: parameter forwarded to ``sparse_np2tf_proximity``.

        Returns:
            networks: list indexed by network id: the ``houses_net.npz`` network, the
                forcing matrix built from ``proximity.npz``, and ``None`` for 'Skip' ticks.
            net_lookup: dict mapping each network name to its index in ``networks``.
        """
        # load_npz accepts a path directly, so no file handle is left open.
        home_net = sparse_np2tf(sparse.load_npz(model_dir + 'houses_net.npz'))
        forcing_matrix = sparse_np2tf_proximity(sparse.load_npz(model_dir + 'proximity.npz'),
                                                forcing_delta)

        networks = [home_net, forcing_matrix, None]
        net_lookup = {'home_net': 0, 'Forcing': 1, 'Skip': 2}

        return networks, net_lookup

    def _sample_event_times(self):
        """Draw per-person (latency, recovery) int32 times from geometric distributions.

        NOTE(review): TFP's Geometric counts failures before the first success (support
        starts at 0), so the sample mean is ``<param>_mean - 1``. Confirm this is intended.
        """
        latency = tf.cast(
            tfp.distributions.Geometric(probs=1 / self.params['latency_mean'])
            .sample(sample_shape=[self.num_people]), tf.int32)
        recovery = tf.cast(
            tfp.distributions.Geometric(probs=1 / self.params['duration_mean'])
            .sample(sample_shape=[self.num_people]), tf.int32)
        return latency, recovery

    @staticmethod
    def _encode_schedule(schedule, network_lookup):
        """Convert schedule entries into ``(net_select, start_of_day)`` lists of ints.

        ``net_select`` holds the ``network_lookup`` index of each tick. 'Skip' ticks map to
        the 'Skip' entry and need no 'network' key. ``start_of_day`` is 1 where the tick's
        ``day_tick`` is 0 and 0 otherwise. Order follows ``schedule.values()``.
        """
        net_select = []
        start_of_day = []
        for tick in schedule.values():
            key = 'Skip' if tick['tick_type'] == 'Skip' else tick['network']
            net_select.append(network_lookup[key])
            start_of_day.append(1 if tick['day_tick'] == 0 else 0)
        return net_select, start_of_day

    def setup_sim(self):
        """Build the initial simulation state and the tensor-encoded schedule.

        Returns:
            state: ``State`` with everyone susceptible (s=1; l, i, r, net all 0).
            event_cache: ``EventCache`` holding sampled latency/recovery times.
            net_select: int64 tensor of network indices, one per schedule tick.
            start_of_day: int32 tensor, 1 where the tick starts a day, else 0.
        """
        latency, recovery = self._sample_event_times()

        # NB S: 0, L: 1, I: 2, R: 3
        shape = [self.num_people]
        zeros = tf.zeros(shape, dtype=tf.float32)
        state = State(s=tf.ones(shape, dtype=tf.float32), l=zeros, i=zeros, r=zeros, net=zeros)
        event_cache = EventCache(latency=latency, recovery=recovery)

        net_select, start_of_day = self._encode_schedule(self.schedule, self.network_lookup)
        net_select = tf.convert_to_tensor(net_select, dtype=tf.int64)
        start_of_day = tf.convert_to_tensor(start_of_day, dtype=tf.int32)

        return state, event_cache, net_select, start_of_day

    @staticmethod
    def _random_infection_mask(num_people, num_infected):
        """Return a float32 0/1 vector of length ``num_people`` with ``num_infected`` ones
        at distinct positions chosen with ``np.random.choice``."""
        mask = np.zeros(num_people, dtype=np.float32)
        mask[np.random.choice(num_people, num_infected, replace=False)] = 1
        return mask

    def create_infected(self, num_infected):
        """Reset the population and seed ``num_infected`` random infections.

        Updates ``self.pop_state``, the compartments of ``self.state`` (infected -> I,
        everyone else -> S, L and R cleared) and ``self.event_cache`` (fresh draws; infected
        people get a random remaining recovery time).
        """
        self.pop_state = tf.convert_to_tensor(
            self._random_infection_mask(self.num_people, num_infected), dtype=tf.float32)

        zeros = tf.zeros_like(self.pop_state)
        self.state = self.state._replace(s=tf.ones_like(self.pop_state) - self.pop_state,
                                         l=zeros, i=self.pop_state, r=zeros)

        latency, recovery = self._sample_event_times()

        # Infection time is unknown for seeded cases: pick a remaining recovery time
        # uniformly in [1, 7]. tf.random.uniform's maxval is exclusive, hence 8.
        inf_mask = tf.equal(self.pop_state, 1)
        rand_days = tf.random.uniform(shape=[self.num_people], minval=1, maxval=8,
                                      dtype=tf.int32)
        recovery = tf.where(inf_mask, rand_days, recovery)

        self.event_cache = self.event_cache._replace(latency=latency, recovery=recovery)

    def iterate(self):
        """Run ``self.num_runs`` simulations and write each to ``self.out_file_path``.

        The HDF5 file is overwritten. Run ``k`` is stored as ``sim_output_<k>`` under the
        ``/inf_data`` and ``/inf_net`` groups. The file is closed even if a run fails.
        """
        # One inf_data row per simulated day; invariant across runs.
        num_days = int(tf.reduce_sum(self.start_of_day))

        with h5.open_file(self.out_file_path, mode='w') as out_file:
            out_file.create_group("/", 'inf_data')
            out_file.create_group("/", 'inf_net')

            for run in range(self.num_runs):
                # reset infections
                self.create_infected(self.num_infected)

                inf_data, net_infect = tf_iterate(net_select=self.net_select,
                                                  start_of_day=self.start_of_day,
                                                  state=self.state,
                                                  event_cache=self.event_cache,
                                                  params=self.params,
                                                  networks=self.networks)

                inf_data = inf_data.numpy().reshape((num_days, self.num_people))
                self.output_sim(inf_data, net_infect.numpy(), out_file, f'sim_output_{run}')

    def output_sim(self, inf_data, net_infect, out_file, table_name='sim_output'):
        """Write one run's outputs to zlib-compressed HDF5 arrays.

        ``inf_data`` goes to ``/inf_data/<table_name>`` and ``net_infect`` to
        ``/inf_net/<table_name>``. Both groups must already exist in ``out_file``.
        NOTE(review): values are stored as uint8; assumes they lie in [0, 255].
        """
        atom = h5.UInt8Atom()
        filters = h5.Filters(complevel=5, complib='zlib')

        for group, data in ((out_file.root.inf_data, inf_data),
                            (out_file.root.inf_net, net_infect)):
            carray = out_file.create_carray(group, table_name, atom, data.shape,
                                            filters=filters)
            carray[:] = data