from __future__ import print_function, division

from ...misc.binary_sum_tree import BinarySumTree

import keras
import math
import numpy as np
import random
import sys
import time

if sys.version_info < (3,):
    import Queue as queue
else:
    import queue


class KerasDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` serving mini-batches cut out of SizeBatchData objects.

    Every batch stored in the given data sets (for the selected types) is
    split into ``ceil(len(batch) / batch_size)`` consecutive chunks of equal
    size; the last chunk of a stored batch additionally takes the remaining
    entries (and can therefore be larger than ``batch_size``). Each chunk is
    one item of this sequence.
    """

    def __init__(self, data, x_fields=None, y_fields=None,
                 types=None, batch_size=100,
                 x_converter=None, y_converter=None, y_remember=None,
                 shuffle=True, count_diff_samples=None,
                 similarity=None, class_weights=None):
        """

        :param data: Single or list of SizeBatchData object(s)
        :param x_fields: Fields in SizeBatchData entries to use as x data
                         if not given, uses all fields which are not used
                         for y values
        :param y_fields: fields in SizeBatchData entries to use as y data
                         if not given, uses the last field as y value
        :param types: types from the SizeBatchData to use
                      (default: ["init", "inter"])
        :param batch_size: target number of entries per generated batch
        :param x_converter: optional callable applied to the x data of a
                            batch before it is returned
        :param y_converter: optional callable applied to the y data of a
                            batch before it is returned
        :param y_remember: List object. If given, all y values for which data
                           was generated are appended (in the batches they
                           are generated, thus, at the end this is a list of
                           numpy arrays)(any object with an
                           append method can be used). Example use case is, if
                           a network evaluation only returns predictions, but
                           not the original values
        :param shuffle: shuffle data after each epoch
        :param count_diff_samples: If None it does not keep track of the
                                   different samples generated. Otherwise this
                                   is a callable which expects two arguments
                                   (X and Y data of sample) and generates a
                                   hash for them.
        :param similarity: If None it does not measure the similarity between
                           a generated value X and all entries of an iterable
                           of SizeBatchData (which would be given in
                           similarity).
                           Otherwise similarity is a triple:
                               -list to store the similarities in batches,
                               -iterable of SizeBatchData to compare to,
                               -method to calculate the similarity with the
                                interface (sample, iterable of SizeBatchData)
                           Using this feature will cost a lot of performance!
        :param class_weights: If given, a container indexed by the first y
                              value of a sample which provides the weight of
                              that sample. The sample weights are then
                              returned as third element of every batch.
        """
        types = ["init", "inter"] if types is None else types
        self.data = data if isinstance(data, list) else [data]
        # Always read the field count from the normalised list: 'data' itself
        # may be a single SizeBatchData object which must not be indexed.
        nb_fields = self.data[0].nb_fields
        self.y_fields = (nb_fields - 1) if y_fields is None else y_fields
        self.x_fields = x_fields
        if self.x_fields is None:
            def _is_y_field(field):
                # y_fields is either a container of indices or a single index
                try:
                    return field in self.y_fields
                except TypeError:
                    return field == self.y_fields

            self.x_fields = [i for i in range(nb_fields)
                             if not _is_y_field(i)]
        self.x_converter = x_converter
        self.y_converter = y_converter
        self.y_remember = y_remember
        self.similarity = similarity

        self.types = types
        self.batch_size = batch_size
        # Entries: (data set index, sample type, batch index, start, end)
        self.batch_order = []
        for idx_ds, ds in enumerate(self.data):
            for sample_type in types:
                if sample_type not in ds.data:
                    continue
                for idx_batch, batch in enumerate(ds.data[sample_type]):
                    count = int(math.ceil(len(batch) / float(batch_size)))
                    step = 0 if count == 0 else len(batch) // count
                    start = 0
                    for i in range(count):
                        # The last chunk takes all remaining entries
                        end = len(batch) if i == count - 1 else start + step
                        self.batch_order.append(
                            (idx_ds, sample_type, idx_batch, start, end))
                        start += step
        self._next = -1

        self.shuffle = shuffle
        self.count_diff_samples = count_diff_samples
        # Remembered to restore count_diff_samples in reset()
        self._mem_count_diff_samples = count_diff_samples
        self.generated_sample_hashes = set()

        self.class_weights = class_weights

    def reset(self):
        """
        Reset this Generator to be used anew: re-enables the counting of
        different samples, forgets the collected sample hashes and empties
        y_remember (if given).
        :return:
        """
        self.count_diff_samples = self._mem_count_diff_samples
        self.generated_sample_hashes = set()
        if self.y_remember is not None:
            clear = getattr(self.y_remember, "clear", None)
            if clear is None:
                # e.g. Python 2 lists do not provide clear()
                del self.y_remember[:]
            else:
                clear()

    def __len__(self):
        """Denotes the number of batches per epoch"""
        return len(self.batch_order)

    def __getitem__(self, index):
        """
        Generate one batch of data
        :param index: position of the batch in the current batch order
        :return: (x, y) or, if class weights were given,
                 (x, y, sample_weights)
        """
        idx_ds, sample_type, idx_batch, start, end = self.batch_order[index]
        entries = self.data[idx_ds].data[sample_type][idx_batch][start:end]
        x = entries[:, self.x_fields]
        y = entries[:, self.y_fields]

        if self.count_diff_samples is not None:
            for sample_x, sample_y in zip(x, y):
                self.generated_sample_hashes.add(
                    self.count_diff_samples(sample_x, sample_y))

        if self.similarity is not None:
            (sim_batches, data_sets, measure) = self.similarity
            ary = np.empty(len(entries), dtype=float)
            for i in range(len(x)):
                ary[i] = measure(entries[i, :], data_sets)
            sim_batches.append(ary)

        # Weights are looked up on the unconverted y values
        sample_weights = None
        if self.class_weights is not None:
            sample_weights = np.array([self.class_weights[yy[0]] for yy in y])
        if self.x_converter is not None:
            x = self.x_converter(x)
        if self.y_converter is not None:
            y = self.y_converter(y)
        if self.y_remember is not None:
            self.y_remember.append(y)

        if sample_weights is None:
            return x, y
        return x, y, sample_weights

    def on_epoch_end(self):
        """Updates indexes after each epoch"""
        if self.y_remember is not None:
            print("BATCH END:")
        self.count_diff_samples = None  # All different samples have been seen
        if self.shuffle:
            # Shuffle the entries within every stored batch (in place) and
            # afterwards the order in which the chunks are served.
            for ds in self.data:
                for sample_type in self.types:
                    if sample_type not in ds.data:
                        continue
                    for batch in ds.data[sample_type]:
                        np.random.shuffle(batch)
            random.shuffle(self.batch_order)


class QueueDataGenerator(object):
    """Iterator which builds batches of a fixed size from data of a queue.

    Every item fetched from the queue has to be a sequence of data elements
    (e.g. numpy arrays for X and Y) which all contain the same number of
    samples along their first axis. Fetched items are collected until at
    least ``batch_size`` samples are available. Surplus samples are kept for
    the following batches.

    Calling ``next`` returns ``None`` (instead of raising StopIteration) if
    no data could be fetched from the queue in time.
    """
    batch_size = property(lambda self: self._batch_size)
    samples_load = property(lambda self: self._samples_load)
    samples_requested = property(lambda self: self._samples_requested)

    def __init__(self, in_queue, batch_size,
                 fetch_timeout=None, time_limit=None):
        """
        :param in_queue: queue object to fetch the data from (its ``get``
                         is called with ``block`` and ``timeout``)
        :param batch_size: number of samples per generated batch
        :param fetch_timeout: maximum time in seconds to wait on a single
                              fetch from the queue (None = no limit)
        :param time_limit: time in seconds, counted from the construction of
                           this object, after which fetches do not wait for
                           new data anymore (None = no limit)
        """
        self._queue = in_queue
        self._batch_size = batch_size
        self._block = True
        self._first_fetch = True

        self._max_time = (None if time_limit is None
                          else time.time() + time_limit)
        self._fetch_timeout = fetch_timeout

        self._samples_batch = 0  # samples currently stored in _next_batch
        self._next_batch = None  # per data element a list of array pieces

        self._samples_requested = 0
        self._samples_load = 0

    def __iter__(self):
        return self

    def _fetch_new_data_from_queue(self):
        """
        Fetch the next item from the queue. The very first fetch is always
        blocking.
        :return: fetched item or None if the queue provided no item in time
        """
        limits = []
        if self._max_time is not None:
            limits.append(self._max_time - time.time())
        if self._fetch_timeout is not None:
            limits.append(self._fetch_timeout)
        # queue.Queue.get raises a ValueError for negative timeouts. After
        # the time limit has passed we only poll the queue (timeout 0).
        timeout = max(0, min(limits)) if limits else None

        try:
            data = self._queue.get(block=self._block or self._first_fetch,
                                   timeout=timeout)
        except queue.Empty:
            return None
        self._samples_load += len(data[0])
        self._first_fetch = False
        return data

    def __next__(self):
        """
        :return: list with one array of batch_size samples per data element
                 or None if a fetch from the queue hit its timeout (already
                 fetched samples are kept for the next call).
        """
        while self._samples_batch < self._batch_size:
            next_data = self._fetch_new_data_from_queue()
            if next_data is None:  # Hit timeout, report back to perform unblock
                return None

            assert all(len(next_data[0]) == len(elem) for elem in next_data), \
                "Not matching batch sizes: %s" % ", ".join(
                    [str(len(elem)) for elem in next_data])
            if self._next_batch is None:
                self._next_batch = [[] for _ in next_data]
            for no, elem in enumerate(next_data):
                self._next_batch[no].append(elem)
            self._samples_batch += len(next_data[0])

        merged = [np.concatenate(pieces, axis=0)
                  for pieces in self._next_batch]

        next_data = [elem[:self._batch_size] for elem in merged]
        # Keep ALL surplus samples (slice, not a single row) such that they
        # match the bookkeeping in _samples_batch and can be concatenated
        # with the data fetched later.
        self._next_batch = [[] if len(elem) <= self._batch_size else
                            [elem[self._batch_size:]]
                            for elem in merged]
        self._samples_batch -= self._batch_size

        self._samples_requested += self._batch_size
        return next_data

    def next(self):
        """Python 2 iterator protocol, see __next__."""
        return self.__next__()

    def has_importance_weights(self):
        """:return: False, this generator provides no importance weights"""
        return False

    def has_priorities(self):
        """:return: False, this generator does not use sample priorities"""
        return False


class BufferedQueueDataGenerator(QueueDataGenerator):
    """Base class for generators which store the queue data in a buffer.

    Subclasses provide the buffer (_get_new_buffer), the insertion of fetched
    data (_add_new_data_into_buffer) and the selection of the samples for a
    batch (_get_next_batch). Except for the very first fetch, the queue is
    read without blocking.
    """
    max_buffer_size = property(lambda self: self._max_buffer_size)

    def __init__(self, in_queue, batch_size, buffer_size,
                 fetch_timeout=None, time_limit=None):
        """
        :param in_queue: queue to fetch the data from
        :param batch_size: number of samples per generated batch
        :param buffer_size: maximum size of the buffer
        :param fetch_timeout: forwarded to QueueDataGenerator
        :param time_limit: forwarded to QueueDataGenerator
        """
        QueueDataGenerator.__init__(
            self, in_queue, batch_size,
            fetch_timeout=fetch_timeout, time_limit=time_limit)
        # The size has to be known before the buffer is created
        self._max_buffer_size = buffer_size
        self._buffer = self._get_new_buffer()

        self._block = False

        # Determined from the first data fetched from the queue
        self._shapes = None
        self._dtypes = None

    def _initialize_shapes_and_dtypes(self, new_data):
        """
        Remember for every data element the shape of a batch (batch size
        followed by the trailing dimensions of the element) and its dtype.
        For None elements None is stored.
        :param new_data: sequence of data elements fetched from the queue
        """
        shapes = []
        dtypes = []
        for elem in new_data:
            if elem is None:
                shapes.append(None)
                dtypes.append(None)
            else:
                shapes.append([self.batch_size] + list(elem.shape[1:]))
                dtypes.append(elem.dtype)
        self._shapes = shapes
        self._dtypes = dtypes

    def _get_new_buffer(self):
        """Create the (empty) buffer object. To be provided by subclasses."""
        raise NotImplementedError("To implement")

    def _add_new_data_into_buffer(self, new_data):
        """Insert fetched data (may be None). To be provided by subclasses."""
        raise NotImplementedError("To implement")

    def _get_next_batch(self):
        """Select the next batch. To be provided by subclasses."""
        raise NotImplementedError("To implement")

    def __next__(self):
        """
        Fetch (at most) one item from the queue, hand it to the buffer and
        return what _get_next_batch produces.
        :return: None as long as no data was ever fetched, otherwise the
                 result of _get_next_batch
        """
        fetched = self._fetch_new_data_from_queue()
        if self._shapes is None:
            if fetched is None:
                # Nothing received so far, batch layout still unknown
                return None
            self._initialize_shapes_and_dtypes(fetched)
        self._add_new_data_into_buffer(fetched)
        self._samples_requested += self.batch_size
        return self._get_next_batch()


class ReplayQueueDataGenerator(BufferedQueueDataGenerator):
    """Replay buffer which samples its batches uniformly at random.

    The buffer is a list of the blocks fetched from the queue. If it holds
    more than max_buffer_size samples, the oldest samples are removed.
    """
    curr_buffer_size = property(lambda self: self._curr_buffer_size)

    def __init__(self, in_queue, batch_size, buffer_size,
                 fetch_timeout=None, time_limit=None):
        BufferedQueueDataGenerator.__init__(
            self, in_queue, batch_size, buffer_size,
            fetch_timeout=fetch_timeout, time_limit=time_limit)
        self._buffer = self._get_new_buffer()
        self._curr_buffer_size = 0

    def _get_new_buffer(self):
        """:return: new empty list used to store the fetched blocks"""
        return []

    def _add_new_data_into_buffer(self, new_data):
        """
        Append a block to the buffer and remove the oldest samples until the
        buffer holds at most max_buffer_size samples.
        :param new_data: block of data elements (ignored if None or empty)
        """
        if new_data is None:
            return
        nb_new = len(new_data[0])
        if nb_new == 0:
            return

        self._buffer.append(new_data)
        self._curr_buffer_size += nb_new
        overflow = self.curr_buffer_size - self.max_buffer_size
        if overflow <= 0:
            return

        # Count the samples of the leading blocks which can be dropped as a
        # whole. Afterwards 'first_kept' is the index of the first block
        # which is not dropped completely ('block' is that block).
        dropped = 0
        for first_kept, block in enumerate(self._buffer):
            block_size = len(block[0])
            if dropped + block_size > overflow:
                break
            dropped += block_size

        # The data at buffer[first_kept] cannot be empty!
        assert dropped + len(block[0]) >= overflow, \
            "Selected too few samples: %i+%i/%i %i" % (
                dropped, len(block[0]), overflow, first_kept)

        # Forget the whole blocks
        if dropped > 0:
            self._buffer = self._buffer[first_kept:]
            self._curr_buffer_size -= dropped

        # Cut the remaining surplus from the front of the oldest block left.
        # This block has strictly more samples than have to be removed.
        surplus = overflow - dropped
        if surplus > 0:
            oldest = self._buffer[0]
            assert len(oldest[0]) > surplus, \
                "buffer smaller than expected: %i/%i" % (
                    len(oldest[0]), surplus)
            self._buffer[0] = tuple(
                None if elem is None else elem[surplus:]
                for elem in oldest)
            self._curr_buffer_size -= surplus

        assert self.curr_buffer_size == self.max_buffer_size, \
            "buffer larger than allowed and expected"
        assert (sum(len(x[0]) for x in self._buffer) ==
                self.curr_buffer_size), "Invalid curr buff count: %i/%i" % (
            sum(len(x[0]) for x in self._buffer), self.curr_buffer_size)

    def _get_next_batch(self):
        """
        Draw batch_size different samples uniformly from the buffer.
        :return: list with one array per data element (None for data elements
                 without shape) or None if the buffer holds fewer samples
                 than a batch requires
        """
        if self._curr_buffer_size < self._batch_size:
            return None

        # One output array per data element of a sample (Xi..., Yj...)
        next_batch = []
        for shape, dtype in zip(self._shapes, self._dtypes):
            next_batch.append(
                None if shape is None else np.empty(shape, dtype=dtype))

        # Sorted global indices (over the concatenated blocks) of the samples
        chosen = np.array(sorted(random.sample(range(0, self.curr_buffer_size),
                                               self._batch_size)))

        block_start = 0  # global index of the first sample of the block
        taken = 0  # number of chosen indices which are already served
        for block in self._buffer:
            block_end = block_start + len(block[0])

            # chosen[first:taken] are the indices falling into this block
            first = taken
            while (taken < self._batch_size and
                   chosen[taken] < block_end):
                taken += 1

            local_indices = chosen[first:taken] - block_start
            for no, buf in enumerate(block):
                if next_batch[no] is not None:
                    next_batch[no][first:taken] = buf[local_indices]

            block_start = block_end
            if taken == self._batch_size:
                break
        return next_batch


# Priority value given to every sample when it is inserted into the buffer of
# the PrioritizedReplayQueueDataGenerator.
MAX_PRIORITY = 1e6


class PrioritizedReplayQueueDataGenerator(BufferedQueueDataGenerator):
    """Replay buffer which samples its batches according to priorities.

    Every sample is stored as a node of a BinarySumTree. New samples are
    inserted with MAX_PRIORITY. After a batch was generated, the priorities
    of its samples can be replaced via update_priorities.
    """
    curr_buffer_size = property(lambda self: len(self._buffer))

    def __init__(self, in_queue, batch_size, buffer_size,
                 beta_generator=None, fetch_timeout=None, time_limit=None):
        """
        :param in_queue: queue to fetch the data from
        :param batch_size: number of samples per generated batch
        :param buffer_size: maximum number of samples in the buffer
        :param beta_generator: iterator providing for every batch the exponent
                               used to calculate the importance weights. If
                               None, no importance weights are generated.
        :param fetch_timeout: forwarded to BufferedQueueDataGenerator
        :param time_limit: forwarded to BufferedQueueDataGenerator
        """
        BufferedQueueDataGenerator.__init__(
            self, in_queue, batch_size, buffer_size,
            fetch_timeout=fetch_timeout, time_limit=time_limit)
        self._beta_generator = beta_generator
        self._last_batch = None  # nodes of the batch generated last

    def _get_new_buffer(self):
        """:return: new BinarySumTree limited to max_buffer_size"""
        return BinarySumTree(size_limit=self.max_buffer_size)

    def _add_new_data_into_buffer(self, new_data):
        """
        Add every sample of new_data as own node with MAX_PRIORITY.
        :param new_data: block of data elements (ignored if None)
        """
        if new_data is None:
            return
        for idx in range(len(new_data[0])):
            self._buffer.add_node(value=MAX_PRIORITY,
                                  data=[None if x is None else x[idx]
                                        for x in new_data])
        assert self.curr_buffer_size <= self.max_buffer_size, \
            "buffer larger than allowed and expected"

    def _get_next_batch(self):
        """
        Split the total priority into batch_size intervals of equal size and
        look up one sample for a uniformly drawn value of every interval.
        :return: None if the buffer holds fewer samples than a batch requires,
                 otherwise the batch data (list with one stacked array per
                 data element) or, if a beta generator is given, the pair
                 (batch data, importance weights)
        """
        if self.curr_buffer_size < self.batch_size:
            return None

        interval_size = self._buffer.total_priority / self.batch_size
        priorities = [
            random.uniform(i * interval_size, (i + 1) * interval_size)
            for i in range(self.batch_size)]
        # Redraw the last value until it is strictly below the total priority
        # NOTE(review): this does not terminate if the total priority is 0
        # (e.g. after all priorities were updated to 0) - confirm that this
        # cannot happen.
        while priorities[-1] == self._buffer.total_priority:
            priorities[-1] = random.uniform(
                self._buffer.total_priority - interval_size,
                self._buffer.total_priority)

        nodes = [self._buffer.lookup_value(p) for p in priorities]
        self._last_batch = nodes
        data = [None if self._dtypes[idx] is None else
                np.stack([n.data[idx] for n in nodes])
                for idx in range(len(self._dtypes))]

        if self._beta_generator is not None:
            beta = next(self._beta_generator)
            iw = [(1 / (len(nodes) * n.value)) ** beta
                  for n in nodes]
            return data, np.array(iw)
        else:
            return data

    def has_importance_weights(self):
        """:return: True iff batches come with importance weights"""
        return self._beta_generator is not None

    def has_priorities(self):
        """:return: True, the samples of this generator have priorities"""
        return True

    def update_priorities(self, losses):
        """
        Use the losses as new priorities for the samples of the last batch.
        :param losses: sequence of losses. Its length has to be a multiple of
                       the size of the last batch. If it is longer than the
                       last batch, the losses of every sample are averaged.
        """
        assert self._last_batch is not None, \
            "no last batch to update priorities"
        # Assumption: If losses is multiple of last_batch length, then last
        # batch is trained on multiple times (the i-th element in last batch
        # is on location x * len(batch) + i in the losses.
        nb_samples = len(self._last_batch)
        assert len(losses) % nb_samples == 0, \
            "%i, %i" % (len(losses), nb_samples)
        if len(losses) > nb_samples:  # Average loss
            # asarray: also accept losses which are no numpy array (e.g. list)
            repetitions = len(losses) // nb_samples
            losses = np.asarray(losses).reshape(
                repetitions, nb_samples).mean(axis=0)
        for n, l in zip(self._last_batch, losses):
            n.change_value(l)
        self._last_batch = None




import tensorflow as tf
from keras import backend
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
# from tensorflow.tools.graph_transforms import TransformGraph


def store_keras_model_as_protobuf(
        model, directory=".", file_model="network.pb",
        quantize=False, theano=False, num_outputs=1, prefix_outputs="",
        show_outputs=False, store_graphdef=None):
    """
    Copyright (c) 2017, by the Authors: Amir H. Abdi
    This software is freely available under the MIT Public License.

    Converts a Keras model into a frozen TensorFlow graph (variables are
    converted to constants) and stores it as binary Protobuf file.

    While the model is stored the Keras learning phase is set to 0. The
    previous learning phase is restored afterwards (also if storing fails).
    The Keras image data format is changed by this function (see theano) and
    NOT restored.

    :param model: Keras model to store
    :param directory: directory to store the file(s) in (None and "" are
                      treated as ".")
    :param file_model: file name for the Protobuf file
    :param quantize: has to be False, quantization is currently disabled
    :param theano: if True the image data format is set to 'channels_first'
                   otherwise to 'channels_last'
    :param num_outputs: number of outputs of the model to name in the graph
    :param prefix_outputs: prefix for the names of the output nodes. The i-th
                           output node is named <prefix_outputs><i>
    :param show_outputs: if True, prints the names of the output nodes
    :param store_graphdef: if not None, file name to additionally store the
                           graph definition as text
    :return: None
    :raises AssertionError: if quantize is set
    """
    # Fail before any global state is modified. Raised explicitly, because
    # assert statements are removed when running python with -O.
    # (Quantization used TransformGraph with the transforms
    # "quantize_weights" and "quantize_nodes". This worked for
    # tensorflow 1.X, but was not updated for >1.X)
    if quantize:
        raise AssertionError("This feature is currently disabled.")

    # Prepare variables
    if directory is None or directory == "":
        directory = "."

    previous_learning_phase = backend.learning_phase()
    if isinstance(previous_learning_phase, tf.Tensor):
        # HACK!!! Uses the last character of the tensor name as value
        previous_learning_phase = int(previous_learning_phase.name[-1])
    backend.set_learning_phase(0)

    try:
        backend.set_image_data_format(
            'channels_first' if theano else 'channels_last')

        # Set node names in computation graph
        pred_node_names = ["%s%i" % (prefix_outputs, i)
                           for i in range(num_outputs)]
        for i, pred_node_name in enumerate(pred_node_names):
            tf.identity(model.outputs[i], name=pred_node_name)
        if show_outputs:
            print('Output nodes names are: ', pred_node_names)

        # Store readable GraphDef
        sess = backend.get_session()
        if store_graphdef is not None:
            tf.train.write_graph(sess.graph.as_graph_def(),
                                 directory, store_graphdef, as_text=True)

        # Store Protobuf
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph.as_graph_def(), pred_node_names)
        graph_io.write_graph(constant_graph, directory, file_model,
                             as_text=False)
    finally:
        backend.set_learning_phase(previous_learning_phase)
