#!/usr/bin/env python

"""Single process contact networks"""

import os
import itertools
import tables as h5
import networkx as nx
from scipy import sparse

from utils.misc_funcs import split_list


class SingleNetworkProcessContact:
    """
    Build and save the contact network for one subset of the model's locations.

    All work happens in ``__init__``: a graph with one node per individual is
    built, the locations are processed chunk by chunk, each chunk's adjacency
    matrix is written under ``save_dir`` as an ``.npz`` file, and finally the
    location objects themselves are written to an HDF5 file.
    """

    def __init__(self, object_to_network,
                 network_str, individuals,
                 households, save_dir,
                 num_chunks):
        """
        Store the inputs and immediately build and save the contact network.

        :param object_to_network: location objects whose members are in
            contact. When ``network_str == 'houses'`` each must expose
            ``occupants``; otherwise each must expose ``users``.
        :param network_str: location-type identifier. ``'houses'`` selects
            ``occupants``; anything else selects ``users``. It is also used in
            the output file names and as the HDF5 array name.
        :param individuals: agents; each needs ``id`` and ``household``.
        :param households: indexed by ``individual.household``; each entry
            needs ``lat`` and ``lon``.
        :param save_dir: output prefix. File names are appended to it by
            plain string concatenation, so it should normally end with a path
            separator.
        :param num_chunks: passed to ``split_list`` to divide the locations
            (presumably the number of chunks -- confirm against
            ``utils.misc_funcs``).
        """
        self.individuals = individuals
        self.households = households
        self.save_dir = save_dir
        self.contact_network(object_to_network, network_str, num_chunks)

    def contact_network(self, locations, location_identifier, number_of_dumps):
        """
        Build the contact graph chunk by chunk and write each chunk to disk.

        A node is created for every individual, tagged with ``loc`` set to the
        ``(lat, lon)`` of their household. The locations are divided with
        ``split_list``. For each chunk, the contact edges are added, the
        adjacency matrix is saved as
        ``contacts_<location_identifier>_<chunk index>.npz``, and the edges are
        removed again so chunks do not accumulate. Finally the location objects
        are saved to ``<location_identifier>.h5``.
        """
        graph = nx.Graph()
        for person in self.individuals:
            home = self.households[person.household]
            graph.add_node(person.id, loc=(home.lat, home.lon))

        chunks = split_list(locations, number_of_dumps)
        for index, chunk in enumerate(chunks):
            # Named ``adjacency`` rather than ``sparse`` so the scipy.sparse
            # module is not shadowed in this scope.
            adjacency = self.chunk(graph, chunk, location_identifier)
            # Remove this chunk's edges but keep the nodes, so the row/column
            # order stays the same for later chunks (assuming every contact id
            # is already one of the individuals' nodes).
            graph.remove_edges_from(list(graph.edges))
            self.handle_network_output(adjacency, location_identifier, index)
        self.handle_object_output(locations, location_identifier)

    def chunk(self, graph, locations, location_identifier):
        """
        Add contact edges for a subset of locations and return the adjacency.

        Every pair of members of the same location is joined by an edge.
        ``'houses'`` locations contribute their ``occupants``; every other
        location type contributes its ``users``.

        :returns: the adjacency matrix of ``graph`` (all edges currently in
            it) as a ``scipy.sparse.csr_matrix``.
        """
        members_attr = self._members_attr(location_identifier)
        self._add_location_edges(graph, locations, members_attr)
        return self._graph_to_csr(graph)

    @staticmethod
    def _members_attr(location_identifier):
        """Return the attribute name that lists a location's members."""
        return 'occupants' if location_identifier == 'houses' else 'users'

    @staticmethod
    def _add_location_edges(graph, locations, members_attr):
        """Add an edge between every pair of members of each location."""
        # In a simple undirected graph (the plain ``nx.Graph`` built by
        # ``contact_network``) the edge (a, b) is the same as (b, a), so each
        # unordered pair only needs adding once. Directed graphs and
        # multigraphs keep the ordered-pair behaviour, because there both
        # directions (or both parallel edges) are distinct.
        if graph.is_directed() or graph.is_multigraph():
            make_pairs = itertools.permutations
        else:
            make_pairs = itertools.combinations
        for loc in locations:
            for first, second in make_pairs(getattr(loc, members_attr), 2):
                graph.add_edge(first, second)

    @staticmethod
    def _graph_to_csr(graph):
        """Return ``graph``'s adjacency matrix as a ``scipy.sparse.csr_matrix``."""
        # networkx 3.0 removed ``to_scipy_sparse_matrix``. Use the replacement
        # when it exists and convert back to the matrix type the old function
        # returned, so the saved ``.npz`` files keep the same type.
        to_array = getattr(nx, 'to_scipy_sparse_array', None)
        if to_array is not None:
            return sparse.csr_matrix(to_array(graph, format='csr'))
        return nx.to_scipy_sparse_matrix(graph, format='csr')

    def handle_network_output(self, result, id_, iter_):
        """
        Save one chunk's sparse matrix to ``<save_dir>contacts_<id_>_<iter_>.npz``.

        Any existing file with that name is removed first.
        """
        filename_ = self.save_dir + f"contacts_{id_}_{iter_}.npz"
        try:
            os.remove(filename_)
        except FileNotFoundError:
            pass  # nothing to replace
        sparse.save_npz(filename_, result)

    def handle_object_output(self, locations, id_):
        """
        Save the location objects to ``<save_dir><id_>.h5``.

        Each location's ``__dict__`` is appended to a variable-length array
        named ``id_`` under the file root, using ``tables.ObjectAtom`` and
        zlib compression (level 5). Any existing file with that name is
        removed first.
        """
        # NOTE: ObjectAtom serialises entries with pickle, so only load these
        # files from trusted locations.
        filters = h5.Filters(complib='zlib', complevel=5)
        filename_ = self.save_dir + f"{id_}.h5"
        try:
            os.remove(filename_)
        except FileNotFoundError:
            pass  # nothing to replace

        with h5.open_file(filename_, mode='w', filters=filters) as h5_file:
            loc_save = h5_file.create_vlarray(h5_file.root, id_, h5.ObjectAtom())
            for location in locations:
                loc_save.append(location.__dict__)
