import os
import glob as g
import copy
import multiprocessing as mp
from scipy import sparse
import pandas as pd
import numpy as np
import pickle
import tables as h5
import geopandas as gpd
from shapely import wkt
from scipy.spatial.distance import cdist

from utils.misc_funcs import split_list
from objects.individual import Individual, Child
from objects.household import Household
from networks.parallel_contact import SingleNetworkProcessContact


class NetworkGenerator():
    def __init__(self, site, child_age, max_dist, num_chunks_proximity, num_chunks_contact):
        """
        Store the run configuration and make sure the output directory exists.

        site: site name; used to build the input CSV names and the output
            directory ``../data/networks/{site}/``.
        child_age: individuals whose age is <= this value are created as
            ``Child`` objects (cast to int).
        max_dist: distances above this value are dropped from the proximity
            network (cast to int).
        num_chunks_proximity: number of chunks the population is split into
            when building the proximity network (cast to int).
        num_chunks_contact: number of chunks handed to the contact-network
            builder (cast to int).
        """
        self.site = site
        self.data_dir = "../data/synthetic_population/"
        self.save_dir = f"../data/networks/{site}/"

        # exist_ok=True avoids the check-then-create race of testing isdir()
        # first: another process creating the directory in between would
        # otherwise make makedirs() raise.
        os.makedirs(self.save_dir, exist_ok=True)

        self.child_age = int(child_age)
        self.max_dist = int(max_dist)
        self.num_chunks_proximity = int(num_chunks_proximity)
        self.num_chunks_contact = int(num_chunks_contact)

        # filled in by create_individuals() / create_households()
        self.individuals = []
        self.households = []


    def create_networks(self):
        """
        Run the whole pipeline: build the population, the proximity network
        and the contact networks, then merge the per-chunk outputs.
        """
        # the synthetic population has to exist before any network is built
        self.create_pop()

        # distance-based network
        self.proximity_network()

        # contact networks: label -> objects the network is built over
        contact_sources = {'houses': self.households}
        labels = list(contact_sources)
        resources = list(contact_sources.values())

        self.init_multi_process_contact_network(resources, labels)

        # merge all sub-networks (and the h5 files) into single outputs
        self.combine_output(labels)

    def create_pop(self):
        """
        Build the household and individual object lists, print their sizes,
        and write the attribute dict of every object to ``people.h5`` in the
        save directory.
        """
        # households first: create_individuals() appends to their occupants
        self.create_households()
        self.create_individuals()

        for label, members in (("individuals", self.individuals),
                               ("households", self.households)):
            print(f"num {label}: {len(members)}")

        # Save
        compression = h5.Filters(complib='zlib', complevel=5)
        out_path = f"{self.save_dir}people.h5"

        # drop any file left over from a previous run
        if os.path.exists(out_path):
            os.remove(out_path)

        with h5.open_file(out_path, mode='w', filters=compression) as store:
            # one variable-length array per object list, written in this order
            for node_name, objects in (('people', self.individuals),
                                       ('households', self.households)):
                node = store.create_vlarray(store.root, node_name, h5.ObjectAtom())
                for obj in objects:
                    node.append(obj.__dict__)

    def create_households(self):
        """
        Load the household CSV for this site and rebuild ``self.households``.

        Rows are sorted by ``new_hhid`` and one ``Household`` is created per
        row from the id and the y / x coordinates of the row's geometry
        (presumably stored as lat / lon - confirm against ``Household``).

        Raises ``Exception`` when a household's ``id`` differs from its
        position in ``self.households``; other methods index the list by
        household id, so the two must agree.
        """
        # Load the data & convert to geodataframe
        df = pd.read_csv(self.data_dir + f"synthetic_{self.site}_households.csv")
        df.sort_values('new_hhid', inplace=True)

        df['geometry'] = df['geometry'].apply(wkt.loads)
        gdf = gpd.GeoDataFrame(df, geometry='geometry')

        # Loop through dataframe & create objects
        self.households = []

        for _, row in gdf.iterrows():
            new_household = Household(row['new_hhid'], row['geometry'].y, row['geometry'].x)

            # Position the new household is about to occupy. Taking it from
            # the list length is O(1); list.index() rescans the whole list on
            # every row and relies on Household equality semantics.
            position = len(self.households)
            self.households.append(new_household)

            if position != new_household.id:
                print(new_household.id, position)
                raise Exception("mismatch of ids")

    def create_individuals(self):
        """
        Load the people CSV for this site and rebuild ``self.individuals``.

        Each row becomes a ``Child`` when its age is <= ``self.child_age`` and
        an ``Individual`` otherwise, using the dataframe index as the id. The
        new id is also appended to the ``occupants`` list of the household
        found at ``self.households[<household id>]``.
        """
        # Load the data & convert to geodataframe
        people = pd.read_csv(self.data_dir + f"synthetic_{self.site}_people.csv")
        people['geometry'] = people['geometry'].apply(wkt.loads)
        people_gdf = gpd.GeoDataFrame(people, geometry='geometry')

        # start from an empty list so repeated calls do not accumulate
        self.individuals = []

        for idx, record in people_gdf.iterrows():
            # pick the class by age, then build the object the same way
            person_cls = Child if record['age'] <= self.child_age else Individual
            person = person_cls(idx, record['ip_sex'], record['age'], record['new_hhid'])

            # register the person with their household, then with the population
            self.households[person.household].occupants.append(person.id)
            self.individuals.append(person)

    def proximity_network(self):
        """
        Create the proximity network: split the population into a number of
        chunks & process each chunk iteratively. N.B. not parallelised, as
        splitting out to different processes doesn't necessarily resolve the
        memory issues.

        For every chunk the distances between the chunk's individuals and all
        individuals are computed, distances above ``self.max_dist`` are
        zeroed, and the remaining non-zero entries are written as a sparse
        (population x population) matrix to ``proximity_{chunk}.npz`` in the
        save directory. Only the rows belonging to the chunk are populated.

        Because zeros are not stored, a pair at distance exactly 0 (e.g. an
        individual and itself) does not appear in the output.
        """
        num_people = len(self.individuals)

        # (lat, lon) of every individual, row-indexed by individual id
        all_xys = np.zeros((num_people, 2))
        # rows of [id, lat, lon]; this is the array that gets chunked
        pop_array = np.array([[i.id, 0., 0.] for i in self.individuals])
        for person in self.individuals:
            home = self.households[person.household]
            all_xys[person.id] = (home.lat, home.lon)
            pop_array[person.id, 1:] = (home.lat, home.lon)

        # split pop array into a number of chunks
        pop_chunks = split_list(pop_array, self.num_chunks_proximity)

        # NOTE(review): the row placement below assumes individual ids equal
        # their position in self.individuals and that split_list returns
        # contiguous, in-order chunks - confirm against split_list.
        chunk_start = 0
        for i, chunk in enumerate(pop_chunks):
            # reshape keeps an empty chunk two-dimensional
            chunk = np.asarray(chunk, dtype=float).reshape(-1, 3)
            chunk_size = len(chunk)

            # Float buffer on purpose: an integer array (as built from literal
            # [0, 0] rows) silently truncates the coordinates on assignment.
            chunk_xys = np.zeros((chunk_size, 2))
            local_rows = chunk[:, 0].astype(int) - chunk_start
            chunk_xys[local_rows] = chunk[:, 1:3]

            data = cdist(chunk_xys, all_xys)
            data[data > self.max_dist] = 0
            rows, cols = np.nonzero(data)
            # shift chunk-local rows to their position in the full matrix
            data = sparse.csr_matrix((data[rows, cols], (rows + chunk_start, cols)),
                                     shape=(num_people, num_people))

            filename_ = (f"{self.save_dir}proximity_{i}.npz")

            if os.path.exists(filename_):
                os.remove(filename_)

            sparse.save_npz(filename_, data)

            chunk_start += chunk_size

    def init_multi_process_contact_network(self, resources, resource_str):
        """
        Start one worker process per resource and wait for all of them.

        ``resources[k]`` is paired with the label ``resource_str[k]``; each
        pair is handed to ``init_individual_contact_process`` in its own
        process. The method returns once every process has been joined.
        """
        workers = []
        for idx, resource in enumerate(resources):
            worker = mp.Process(name=f'Process {idx}',
                                target=self.init_individual_contact_process,
                                args=(resource, resource_str[idx]))
            workers.append(worker)

        # launch everything before waiting so the workers run concurrently
        for worker in workers:
            worker.start()

        for worker in workers:
            worker.join()
            worker.terminate()

    def init_individual_contact_process(self, objects_to_network, network_str):
        """
        Build one contact network; used as the target of a worker process.

        Hands the objects and their label to ``SingleNetworkProcessContact``
        together with the population lists, the save directory and the
        contact chunk count. Nothing is returned.
        """
        builder_args = (
            objects_to_network,
            network_str,
            self.individuals,
            self.households,
            self.save_dir,
            self.num_chunks_contact,
        )
        SingleNetworkProcessContact(*builder_args)



    def combine_output(self, contact_network_str):
        """
        Combine all output data to a single master HDF5 file
        (agent data) and single npz files (network data).

        ``people.h5`` is copied into the ``demography`` group and
        ``houses.h5`` into the ``networks`` group of ``master_pop.h5``.
        The per-chunk proximity files are merged into ``proximity.npz`` and,
        for every label in ``contact_network_str``, the per-chunk contact
        files are merged into ``{label}_net.npz``. The helpers used here
        delete their input files once merged.
        """
        master_filename = self.save_dir + "master_pop.h5"
        # always rebuild the master file from scratch
        if os.path.exists(master_filename):
            os.remove(master_filename)

        # agent data
        individual_file = self.save_dir + "people.h5"
        households_file = self.save_dir + "houses.h5"

        self.copy_h5(individual_file, master_filename, 'demography')
        self.copy_h5(households_file, master_filename, 'networks')

        # network data
        proximity_npzs = [f"{self.save_dir}proximity_{i}.npz"
                          for i in range(self.num_chunks_proximity)]
        self.append_npz_arrays(proximity_npzs, self.save_dir + "proximity.npz")

        for network_type in contact_network_str:
            contacts_npzs = [f"{self.save_dir}contacts_{network_type}_{i}.npz"
                             for i in range(self.num_chunks_contact)]
            self.append_npz_arrays(contacts_npzs, f"{self.save_dir}{network_type}_net.npz")

    def copy_h5(self, from_filename, to_filename, to_group):
        """
        Copy all data from a specific filename to a group in another file.

        ``to_filename`` is opened in append mode and the group ``to_group`` is
        created under its root if missing. The root children of every file
        matching ``from_filename`` are then copied recursively into that
        group. Finally ``from_filename`` itself is deleted.
        """
        with h5.open_file(to_filename, mode='a') as h5fw:
            if '/'+to_group not in h5fw:
                h5fw.create_group(h5fw.root, to_group)
            for h5name in g.glob(from_filename):
                # the context manager closes the source file even when the
                # copy raises, so no handle is leaked
                with h5.open_file(h5name, mode='r') as h5fr:
                    h5fr.root._f_copy_children(h5fw.root['/'+to_group],
                                               recursive=True)
        # only reached when the copy succeeded
        os.remove(from_filename)

    def append_npz_arrays(self, input_filenames, master_filename):
        """
        Sum a list of sparse-matrix npz files into a single master file.

        Every input is loaded with ``sparse.load_npz`` and added to a running
        total. The total is saved to ``master_filename``, the input files are
        deleted, and the shape and max/min of the total are printed.
        """
        first_path, remaining_paths = input_filenames[0], input_filenames[1:]

        combined = sparse.load_npz(first_path)
        for path in remaining_paths:
            combined += sparse.load_npz(path)

        sparse.save_npz(master_filename, combined)

        # the per-chunk files are redundant once the total is on disk
        for path in input_filenames:
            os.remove(path)

        print(f"{master_filename} dimension: {combined.shape}")
        print(f"{master_filename} max/min: {(combined.max(), combined.min())}")
