
import os
import collections
import numpy as np
import tensorflow as tf
# Cap TensorFlow at 4 inter-op and 4 intra-op threads. This runs at import time,
# presumably so it takes effect before any op initialises the TF runtime (TF raises
# a RuntimeError if these are changed afterwards) -- keep it ahead of any op execution.
tf.config.threading.set_inter_op_parallelism_threads(4)
tf.config.threading.set_intra_op_parallelism_threads(4)
import tensorflow_probability as tfp
from scipy import sparse


# Per-person simulation state, one tensor per field: the S/L/I/R compartment fields
# (compared against 1 elsewhere, so presumably 0/1 indicators) and ``net``, which
# accumulates the id of the network through which a person became latent
# (presumably 0 until then).
State = collections.namedtuple('State', 's l i r net')
# Per-person integer countdowns until the next latent->infected (``latency``) and
# infected->recovered (``recovery``) transition.
EventCache = collections.namedtuple('EventCache', 'latency recovery')

# Short alias for the TensorFlow Probability distributions sub-module
# (e.g. ``tfd.Bernoulli`` is used for sampling new infections).
tfd = tfp.distributions  # TensorFlow Probability distributions sub-module

def sparse_np2tf(sparse_matrix):
    """Convert a SciPy sparse matrix into a float32 ``tf.sparse.SparseTensor``.

    :param sparse_matrix: any ``scipy.sparse`` matrix/array supporting ``tocoo()``.
    :return: a 2-D ``tf.sparse.SparseTensor`` with the same dense shape as the input,
        float32 values, and indices in canonical row-major order.
    """
    coo = sparse_matrix.tocoo()
    # SparseTensor expects an int64 [nnz, 2] index matrix.
    indices = np.stack([coo.row, coo.col], axis=-1).astype(np.int64)
    sparse_tensor = tf.sparse.SparseTensor(indices=indices,
                                           values=coo.data.astype(np.float32),
                                           dense_shape=coo.shape)
    # scipy COO data is not guaranteed to be row-major ordered, but most tf.sparse
    # ops assume that ordering; reorder() puts the indices into canonical order.
    return tf.sparse.reorder(sparse_tensor)


def sparse_np2tf_proximity(sparse_proximity_matrix, forcing_delta):
    """Apply a proximity kernel to a sparse distance matrix and convert it to TensorFlow.

    Each explicitly stored entry ``d`` becomes ``delta**2 / (delta**2 + d**2) ** 1.5``
    with ``delta = forcing_delta``. Entries that are not stored stay implicit zeros.

    :param sparse_proximity_matrix: ``scipy.sparse`` matrix of distances; it is NOT
        modified (the kernel is applied to a copy).
    :param forcing_delta: kernel length-scale ``delta``.
    :return: a 2-D float32 ``tf.sparse.SparseTensor`` with canonical row-major indices.
    """
    # Work on a copy so the caller's matrix keeps its original distances; transforming
    # ``.data`` in place would apply the kernel again on every call.
    coo = sparse_proximity_matrix.tocoo(copy=True)
    delta_sq = forcing_delta ** 2
    coo.data = delta_sq / ((delta_sq + coo.data ** 2) ** 1.5)
    indices = np.stack([coo.row, coo.col], axis=-1).astype(np.int64)
    sparse_tensor = tf.sparse.SparseTensor(indices=indices,
                                           values=coo.data.astype(np.float32),
                                           dense_shape=coo.shape)
    # Most tf.sparse ops assume row-major index order, which scipy COO does not guarantee.
    return tf.sparse.reorder(sparse_tensor)


@tf.function
def tf_infect_net(network: tf.sparse.SparseTensor, net_id: int, state: State, rate, name=None):
    """Run one round of transmission over a single contact network.

    Each susceptible person becomes latent with probability ``1 - exp(-lambda)``,
    where ``lambda = rate * (network @ state.i)``.

    :param network: sparse matrix with one column per entry of ``state.i``; its row
        count must match the number of entries in ``state.s``.
    :param net_id: id recorded in ``state.net`` for people infected here. May be a
        Python number or a tensor.
    :param state: current ``State``.
    :param rate: transmission rate multiplying the network-weighted infected count.
    :param name: optional name scope (defaults to ``'infect_net'``).
    :return: updated ``State``: new latents move from ``s`` to ``l`` and have
        ``net_id`` added to their ``net`` entry.
    """
    # tf.name_scope requires a string in TF2, so fall back when no name is given.
    with tf.name_scope(name or 'infect_net'):
        # Force of infection per person: rate times the weighted sum of infected contacts.
        infect_rate = rate * tf.sparse.sparse_dense_matmul(network, state.i[:, tf.newaxis])
        # 1 - exp(-x) computed as -expm1(-x), which keeps precision for tiny rates.
        infect_prob = -tf.math.expm1(-infect_rate)
        draws = tf.reshape(tfd.Bernoulli(probs=infect_prob, dtype=tf.float32).sample(),
                           tf.shape(state.s))
        # Only people currently susceptible can become latent.
        new_latent = state.s * draws
        # tf.cast works for Python numbers and symbolic tensors alike (float() does not
        # accept a graph tensor).
        net_value = tf.cast(net_id, new_latent.dtype)
        return state._replace(s=state.s - new_latent,
                              l=state.l + new_latent,
                              net=state.net + new_latent * net_value)


@tf.function
def tf_latency(state: State, event_cache: EventCache):
    """Count down the latency timer of every currently latent person.

    Entries of ``event_cache.latency`` where ``state.l == 1`` are reduced by one.
    All other entries are kept as they are.

    :param state: current ``State``; only ``state.l`` is read.
    :param event_cache: ``EventCache`` holding the per-person countdowns.
    :return: ``event_cache`` with its ``latency`` field updated.
    """
    is_latent = tf.equal(state.l, 1)
    latency = event_cache.latency
    return event_cache._replace(latency=tf.where(is_latent, latency - 1, latency))


@tf.function
def tf_infected_time(state: State, event_cache: EventCache):
    """Count down the recovery timer of every currently infected person.

    Entries of ``event_cache.recovery`` where ``state.i == 1`` are reduced by one.
    All other entries are kept as they are.

    :param state: current ``State``; only ``state.i`` is read.
    :param event_cache: ``EventCache`` holding the per-person countdowns.
    :return: ``event_cache`` with its ``recovery`` field updated.
    """
    is_infected = tf.equal(state.i, 1)
    recovery = event_cache.recovery
    return event_cache._replace(recovery=tf.where(is_infected, recovery - 1, recovery))


@tf.function
def tf_update_state(state: State,
                    event_cache: EventCache,
                    params: dict):
    """Apply the latent->infected and infected->recovered transitions.

    :param state: ``State`` of per-person 1-D tensors (s, l, i, r, net). The
        compartment fields are compared against 1, so presumably 0/1 indicators.
    :param event_cache: ``EventCache`` of per-person int32 countdowns
        (``latency``, ``recovery``).
    :param params: dict with ``'duration_mean'`` and ``'latency_mean'``;
        ``1 / value`` is used as the success probability of a Geometric distribution.
    :return: ``(state, event_cache)`` after the transitions. People who recovered in
        this call get freshly sampled ``recovery`` and ``latency`` countdowns.
    """
    # Latent -> infected once the latency countdown has run out. ``<= 0`` rather than
    # ``== 0``: Geometric samples (as drawn below) can be 0. A latency that starts at 0
    # is decremented to -1 by ``tf_latency`` (when that runs first, as in
    # ``tf_iterate``) before it is ever checked, so an exact ``== 0`` test would
    # leave that person latent forever.
    became_infected = tf.cast(
        tf.logical_and(event_cache.latency <= 0, tf.equal(state.l, 1)), tf.float32)
    state = state._replace(l=state.l - became_infected, i=state.i + became_infected)

    # Infected -> recovered once the recovery countdown has run out. This reads the
    # updated ``state.i``, so people who became infected just above are included.
    recovered = tf.cast(
        tf.logical_and(event_cache.recovery <= 0, tf.equal(state.i, 1)), tf.float32)
    state = state._replace(i=state.i - recovered, r=state.r + recovered)

    # Resample both countdowns for the people who just recovered.
    # NOTE(review): nothing in this module moves people out of R, so these resampled
    # values are not read again by the code visible here -- confirm intent.
    # NOTE(review): TFP's Geometric(p) counts failures before the first success
    # (support {0, 1, ...}, mean (1 - p) / p), i.e. mean - 1 for p = 1 / mean.
    just_recovered = tf.equal(recovered, 1)
    # tf.shape works even when the static length is unknown, unlike len().
    sample_shape = tf.shape(recovered)

    new_recovery = tf.cast(
        tfd.Geometric(probs=1 / params['duration_mean']).sample(sample_shape=sample_shape),
        tf.dtypes.int32)
    event_cache = event_cache._replace(
        recovery=tf.where(just_recovered, new_recovery, event_cache.recovery))

    new_latency = tf.cast(
        tfd.Geometric(probs=1 / params['latency_mean']).sample(sample_shape=sample_shape),
        tf.dtypes.int32)
    event_cache = event_cache._replace(
        latency=tf.where(just_recovered, new_latency, event_cache.latency))

    return state, event_cache

@tf.function
def tf_iterate(net_select, # 1D Tensor (int32)
               start_of_day, # 1d Tensor (int32)
               state: State,
               event_cache: EventCache,
               params: dict,
               networks): # list
    """Run the simulation on the network(s), one loop iteration per tick.

    On ticks where ``start_of_day[t] == 1``:
    - each person's compartment code is recorded for that day;
    - the latency/recovery countdowns are advanced;
    - the L->I and I->R transitions are applied;
    - local-forcing transmission runs over ``networks[1]`` (recorded as net id 2).

    On every tick, ``net_select[t]`` chooses the transmission step. Index 0 runs
    ``'home'`` transmission over ``networks[0]`` (recorded as net id 1). Any other
    index leaves the state unchanged.

    :param net_select: per-tick network index (cast to int32 before use).
    :param start_of_day: per-tick flag; 1 marks the first tick of a day.
    :param state: initial ``State``.
    :param event_cache: initial ``EventCache`` countdowns.
    :param params: dict read for ``'net_rate'`` and ``'forcing_rate'`` here, and
        passed on to ``tf_update_state``.
    :param networks: list of sparse contact matrices; indices 0 and 1 are used.
    :return: ``(history, net)``. ``history`` stacks one code per person per day,
        equal to ``l + 2*i + 3*r`` (0 = S, 1 = L, 2 = I, 3 = R for 0/1 indicators).
        ``net`` is ``state.net`` after the final tick.
    """
    # tf.size works even when the static length of net_select is unknown (len() does not).
    num_ticks = tf.size(net_select)
    # One entry per day (number of ticks flagged in start_of_day); stack() yields
    # shape [num_days] + shape(state.s).
    inf_data = tf.TensorArray(dtype=tf.float32, size=tf.reduce_sum(start_of_day))

    def body(t, day, state, event_cache, inf_data):
        if start_of_day[t] == 1:
            # Reverse exclusive cumsum over [s, l, i, r], then a sum, gives
            # 0*s + 1*l + 2*i + 3*r: a single compartment code per person.
            state_enum = tf.reduce_sum(tf.cumsum([state.s, state.l, state.i, state.r],
                                                  exclusive=True, reverse=True, axis=0),
                                       axis=0)
            # Record each person's state for this day.
            inf_data = inf_data.write(day, state_enum, name="store_state_enum")

            # Advance latency/infected countdowns, then apply the state transitions.
            event_cache = tf_latency(state, event_cache)
            event_cache = tf_infected_time(state, event_cache)
            state, event_cache = tf_update_state(state, event_cache, params)

            day += 1

            # Local forcing transmission over networks[1], recorded as net id 1 + 1 = 2.
            state = tf_infect_net(networks[1], 2, state, params['forcing_rate'], name='local_forcing')

        # Network index for this tick.
        network_id = tf.cast(tf.gather(net_select, t), tf.int32)

        # Net ids are recorded as list index + 1 so that ids start at 1 (helps reporting).
        # The id is passed as a Python literal (branch 0 always means id 1) so that
        # tf_infect_net receives a plain number for net_id rather than a symbolic tensor.
        # IF ADDITIONAL NETWORKS ARE ADDED, PUT EXTRA ITEMS IN HERE TO REFLECT IT: the list
        # index must be hard-coded for each network type, otherwise you get
        # "TypeError: list indices must be integers or slices, not Tensor".
        state = tf.switch_case(network_id, [lambda: tf_infect_net(networks[0], 1, state, params['net_rate'], name='home'),
                                            lambda: state,
                                            lambda: state])

        return t + 1, day, state, event_cache, inf_data

    def cond(t, *_):
        return t < num_ticks

    _, _, state, event_cache, inf_data = tf.while_loop(
        cond=cond, body=body, loop_vars=(0, 0, state, event_cache, inf_data))

    return inf_data.stack(), state.net
