from .misc import hasher

import numpy as np


class SizeBatchData(object):
    """Store entries grouped into batches of equally sized entries.

    Every entry is stored as a tuple with exactly ``nb_fields`` fields. If an
    entry has too few fields, it is padded with None; if it has too many, the
    surplus fields are dropped.
    The format is:
    data = {type: [BATCHES*]}
    BATCHES = [ENTRIES*]
    ENTRIES = (field1, field2, ...)
    Field i can contain any data. If it is a list, then for all entries within
    a batch the length of field i is the same (this is done because Tensorflow
    and Theano do not allow to receive batches of data where the data is of
    different dimensions within a batch). Entries with the same type and the
    same list lengths share a batch. After finalize(...), every non-empty
    batch is a numpy array and empty batches are dropped.
    """
    def __init__(self, nb_fields, descriptions=None, meta=None, prune=False,
                 hasher=hasher.identity, fprune=None):
        """
        :param nb_fields: number of fields of an entry (if None, then this has
                          to be initialized prior to the first add(...) with
                          'init_nb_fields(nb_fields)')
        :param descriptions: sequence of field descriptions. It is copied into
                             a new list, which is padded with None up to
                             nb_fields.
        :param meta: dict with meta information (stored as given, not copied)
        :param prune: if True, add(...) rejects entries whose hash (computed
                      by hasher) was already seen
        :param hasher: callable which is used to convert the entries into a
                       hashable format, which is then used for duplicate
                       checking. The callable should convert two samples to
                       the same value if and only if they are the same for
                       you and any of them could be pruned as duplicate
        :param fprune: callable which, given an entry, returns True if the
                       entry shall be rejected by add(...)
        """
        # Attention: if adding new attributes, refactor
        # self._copy_fill_fields(...)
        self.nb_fields = None
        # Copy so that padding in init_nb_fields(...) neither mutates the
        # caller's list nor fails for tuples.
        self.descriptions = [] if descriptions is None else list(descriptions)
        if nb_fields is not None:
            self.init_nb_fields(nb_fields)

        self.data = {}
        # Maps (type, list length or -1 per field) -> batch list in self.data.
        self.batches = {}
        self.meta = {} if meta is None else meta
        self.prune = prune
        self.hasher = hasher
        # Hashes of entries accepted while pruning is active; None after
        # finalize(clear_pruning=True).
        self.pruning_set = set()
        self.fprune = fprune

        self.is_finalized = False

    def init_nb_fields(self, nb_fields):
        """Set the number of fields (only once) and pad the descriptions."""
        assert self.nb_fields is None
        self.nb_fields = nb_fields
        missing = self.nb_fields - len(self.descriptions)
        self.descriptions.extend([None] * max(missing, 0))

    def _copy_fill_fields(self, other, only_structure, entry_copier):
        """Copy the structure (and, unless only_structure, the data) into other.

        Entries are re-added with pruning disabled. Every entry already passed
        pruning in this object, and checking them against the copied pruning
        set would reject all of them.
        """
        other.nb_fields = self.nb_fields
        other.descriptions = self.descriptions
        other.meta = self.meta
        other.hasher = self.hasher
        other.prune = False
        other.fprune = None
        if not only_structure:
            if self.is_finalized:
                self._copy_finalized_batches(other, entry_copier)
            else:
                other.add_data(self, False, entry_copier=entry_copier)
            other.pruning_set = (None if self.pruning_set is None
                                 else set(self.pruning_set))
            other.is_finalized = self.is_finalized
        other.prune = self.prune
        other.fprune = self.fprune

    def _copy_finalized_batches(self, other, entry_copier):
        """Copy the numpy batches of this finalized object into other."""
        other.data = {}
        for etype, batches in self.data.items():
            copied = []
            for batch in batches:
                if entry_copier is None:
                    copied.append(np.array(batch, copy=True))
                else:
                    copied.append(np.array(
                        [entry_copier(entry) for entry in batch],
                        dtype=batch.dtype))
            other.data[etype] = copied
        # Finalized objects keep no batch index (see finalize).
        other.batches = None

    def copy(self, only_structure=False, entry_copier=None):
        """
        Return a copy of this data object. This is NOT a deep copy. For example
        changing the field_descriptions or the meta in the original
        data set also changes this value in the copied data.
        (Internal variables are deep copied, as we know how to do this.)
        Finalized objects can be copied, too. Their numpy batches are copied.

        :param only_structure: Return instead a new SizeBatchData object, which
                               has the same structure (e.g. number of fields,
                               field descriptions, meta, and hasher) as this
                               object, but does not contain data entries.
        :param entry_copier: If given, the entries from data are not simply
                             added to this object, but entry_copier(entry) is
                             added to this object. Use case: create independent
                             entries which are not affected by changing them in
                             one of the two SizeBatchData objects
        :return: the new SizeBatchData object
        """
        other = SizeBatchData(self.nb_fields)
        self._copy_fill_fields(other, only_structure=only_structure,
                               entry_copier=entry_copier)
        return other

    def _check_not_finalized(self):
        """Raise TypeError if this object was already finalized."""
        if self.is_finalized:
            raise TypeError("SizeBatchData does not support modifications after"
                            " it was finalized.")

    def _modify_all(self, func):
        """Replace every entry by func(entry), in place in each batch."""
        for batches in self.data.values():
            for batch in batches:
                for idx in range(len(batch)):
                    batch[idx] = func(batch[idx])

    def _over_all_types(self, func, early_stopping=False, provide_type=False):
        """Call func on the batch list of each type (stop on truthy result if
        early_stopping)."""
        for etype, batches in self.data.items():
            result = func(batches, etype) if provide_type else func(batches)
            if early_stopping and result:
                return

    def _over_all_batches(self, func, early_stopping=False, provide_type=False):
        """Call func on every batch (stop on truthy result if early_stopping)."""
        for etype, batches in self.data.items():
            for idx_batch in range(len(batches)):
                batch = batches[idx_batch]
                result = func(batch, etype) if provide_type else func(batch)
                if early_stopping and result:
                    return

    def over_all(self, *args, **kwargs):
        """Public wrapper of _over_all(...)."""
        self._over_all(*args, **kwargs)

    def _over_all(self, func, early_stopping=False, provide_type=False):
        """Call func on every entry (stop on truthy result if early_stopping).

        The entry counts are fixed before iterating, so entries appended
        during the iteration are not visited.
        """
        for etype, batches in self.data.items():
            for idx_batch in range(len(batches)):
                batch = batches[idx_batch]
                for idx_entry in range(len(batch)):
                    if provide_type:
                        result = func(batch[idx_entry], etype)
                    else:
                        result = func(batch[idx_entry])
                    if early_stopping and result:
                        return

    def convert_field(self, field, converter):
        """Replace field ``field`` of every entry by converter(old value).

        Tuple entries are rebuilt. Mutable entries (e.g. rows of finalized
        numpy batches) are updated in place. The pruning set is not updated.
        """
        def _convert(entry):
            if isinstance(entry, tuple):
                values = list(entry)
                values[field] = converter(values[field])
                return tuple(values)
            entry[field] = converter(entry[field])
            return entry
        self._modify_all(_convert)

    def get_desc(self, field):
        """Return the description of field index ``field`` (None if unset)."""
        assert self.nb_fields is not None
        if field >= self.nb_fields:
            raise ValueError("Field description access out of bounds.")

        if field >= len(self.descriptions):
            return None
        return self.descriptions[field]

    def add(self, entry, etype=None):
        """Add an entry to the batch matching its type and list field sizes.

        :param entry: indexable sequence of field values. It is normalized to a
                      new tuple of exactly nb_fields values (padded with None
                      or truncated). The given object is not modified.
        :param etype: category (type) to which to add the entry
        :return: True if the entry was added, False if it was pruned
        :raises TypeError: if this object is already finalized
        """
        self._check_not_finalized()
        assert self.nb_fields is not None
        nb_given = len(entry)
        entry = tuple(entry[i] if i < nb_given else None
                      for i in range(self.nb_fields))
        if self.has_pruning() and self.shall_prune(entry, add=True):
            return False

        # Entries are batched by type and by the length of every list field
        # (-1 for non-list fields).
        key = (etype,) + tuple(len(value) if isinstance(value, list) else -1
                               for value in entry)
        batch = self.batches.get(key)
        if batch is None:
            batch = []
            self.data.setdefault(etype, []).append(batch)
            self.batches[key] = batch
        batch.append(entry)
        return True

    def _check_same_fields(self, data):
        """Raise ValueError if data has another field count or descriptions."""
        assert self.nb_fields is not None
        if self.nb_fields != data.nb_fields:
            raise ValueError("Given SizeBatchData object has an incompatible"
                             " number of fields.")
        for i in range(self.nb_fields):
            if self.descriptions[i] != data.descriptions[i]:
                raise ValueError("Descriptions of fields differ in the "
                                 "SizeBatchData objects to combine.")

    def add_data(self, data, check_fields=True, entry_copier=None):
        """
        Add all entries of another SizeBatchData object (keeping their types).

        :param data: another SizeBatchData object
        :param check_fields: if True, verify first that both objects have the
                             same number of fields and descriptions
        :param entry_copier: If given, the entries from data are not simply
                             added to this object, but entry_copier(entry) is
                             added to this object. Use case: create independent
                             entries which are not affected by changing them in
                             one of the two SizeBatchData objects
        :return: None
        """
        if check_fields:
            self._check_same_fields(data)

        if entry_copier is None:
            data._over_all(self.add, provide_type=True)
            return

        def _add_copy(entry, etype=None):
            self.add(entry_copier(entry), etype)
        data._over_all(_add_copy, provide_type=True)

    def empty(self):
        """Return True if no batch contains an entry."""
        return all(len(batch) == 0
                   for batches in self.data.values() for batch in batches)

    def size(self):
        """Return the total number of entries."""
        return sum(len(batch)
                   for batches in self.data.values() for batch in batches)

    def __len__(self):
        return self.size()

    def set_meta(self, name, value):
        self.meta[name] = value

    def has_meta(self, name):
        return name in self.meta

    def get_meta(self, name):
        return self.meta[name]

    def has_pruning(self):
        """Return True if hash pruning or an fprune callable is active."""
        return self.prune or self.fprune is not None

    def in_pruning_set(self, item):
        return item in self.pruning_set

    def clear_pruning_set(self, make_none=False):
        """Empty the pruning set, or drop it entirely if make_none is True."""
        if make_none:
            self.pruning_set = None
        elif self.pruning_set is None:
            # Dropped before (e.g. by finalize); recreate an empty set.
            self.pruning_set = set()
        else:
            self.pruning_set.clear()

    def shall_prune(self, item, add=False, use_fprune=True):
        """
        Decide whether item shall be rejected.

        :param item: normalized entry
        :param add: If True and the item is not pruned, adds its hash to the
                    pruning set (only if pruning is active)
        :param use_fprune: if False, the fprune callable is not consulted
        :return: True if the item shall be pruned
        """
        if self.fprune is not None and use_fprune and self.fprune(item):
            return True
        if self.prune or (self.fprune is not None and add):
            hash_entry = self.hasher(item)
        if self.prune and self.in_pruning_set(hash_entry):
            return True
        # New item (as it was not pruned)
        if self.has_pruning() and add:
            self.pruning_set.add(hash_entry)
        return False

    def finalize(self, clear_pruning=True, dtype=object):
        """Convert every non-empty batch to a numpy array and freeze the object.

        Empty batches are dropped. Afterwards self.batches is None and no
        further modifications are allowed.

        :param clear_pruning: if True, the pruning set is dropped (set to None)
        :param dtype: dtype passed to numpy.array for every batch
        :raises TypeError: if already finalized
        """
        self._check_not_finalized()
        if clear_pruning:
            self.clear_pruning_set(make_none=True)

        for batches in self.data.values():
            # Rebuild in place. Deleting by index while iterating skipped
            # batches and ran past the end of the shrinking list.
            batches[:] = [np.array(batch, dtype=dtype)
                          for batch in batches if len(batch) > 0]
        # They are not synced anymore, thus, make batches invalid to prevent
        # other people wrongly using it
        self.batches = None
        self.is_finalized = True

    def remove_duplicates_from(self, data, hasher=None):
        """
        Removes from THIS SizeBatchData all entries which also occur in data.

        Two entries are considered equal if hasher maps them to the same
        value. Their type is ignored, as for pruning. Only entries currently
        stored in data count, so entries that were removed from data (e.g. by
        splitoff) no longer cause removals here. Hashes of removed entries
        stay in this object's pruning set.

        :param data: SizeBatchData object (must not be finalized)
        :param hasher: callable used to compare entries. If None, the hasher
                       shared by both objects is used.
        :raises ValueError: if no hasher is given and the objects use
                            different hashers
        :raises TypeError: if either object is finalized
        """
        self._check_not_finalized()
        data._check_not_finalized()
        if hasher is None and self.hasher == data.hasher:
            hasher = self.hasher
        if hasher is None:
            raise ValueError("No hashing function given to compare the data"
                             " elements and both objects do not agree on a"
                             " hashing function.")
        # Hash what data actually contains. data.pruning_set may still hold
        # hashes of entries that were removed from data.
        other_hashes = {hasher(entry)
                        for batch in data.batches.values()
                        for entry in batch}
        for batch in self.batches.values():
            # Slice assignment keeps the list shared by self.batches and
            # self.data.
            batch[:] = [entry for entry in batch
                        if hasher(entry) not in other_hashes]

    def remove_duplicates_from_iter(self, datas, hasher=None):
        """Call remove_duplicates_from(...) for every object in datas."""
        for data in datas:
            self.remove_duplicates_from(data, hasher=hasher)

    def splitoff(self, *fractions):
        """
        Randomly move fractions of the entries out of this object.

        :param fractions: (SEQUENCE). Each element in the sequence can be a
                          fraction to split off (e.g. 0.2 to split
                          off 20%, or 0.2 and 0.3 to split off 20% and 30% of
                          the data, where all sets are disjoint). The split
                          data is returned in new SizeBatchData objects in the
                          same order as in the *fractions sequence.
                          Alternatively elements can also be tuples of the
                          format (fraction, SizeBatchData object). For those
                          entries the split off data is directly added to
                          the object.
        :return: SizeBatchData objects containing the split off data in the same
                 order as the fractions are defined
        :raises ValueError: if the fractions sum to more than 1
        """
        self._check_not_finalized()
        total_size = self.size()
        split_sizes = []
        split_objects = []
        sum_fractions = 0.0
        for spec in fractions:
            try:
                fraction, target = spec[0], spec[1]
            except (TypeError, AttributeError):
                # Plain fraction: split off into a new, empty object.
                fraction, target = spec, self.copy(only_structure=True)
            split_sizes.append(int(total_size * fraction))
            sum_fractions += fraction
            split_objects.append(target)
        if sum_fractions > 1.0:
            raise ValueError("Cannot split of more than 100% of a data set. "
                             "That's not how it works.")
        summed_split_size = sum(split_sizes)

        chosen = np.arange(total_size)
        np.random.shuffle(chosen)
        # Chosen global indices, highest first, so deleting from a batch never
        # shifts an index that is still to be processed.
        chosen = np.sort(chosen[:summed_split_size])[::-1]

        # Batches occupy consecutive global index ranges, counted from the top.
        picked = []
        position = 0
        highest = total_size - 1
        for key, batch in self.batches.items():
            lowest = highest - len(batch) + 1
            while position < len(chosen) and chosen[position] >= lowest:
                local_idx = chosen[position] - lowest
                picked.append((batch[local_idx], key[0]))
                del batch[local_idx]
                position += 1
            highest = lowest - 1

        # Distribute the picked entries randomly over the split objects.
        np.random.shuffle(picked)
        start = 0
        for size, target in zip(split_sizes, split_objects):
            for entry, etype in picked[start:start + size]:
                target.add(entry, etype)
            start += size

        return split_objects

    def clear(self):
        """Remove all entries and the pruning set, and allow modifications again."""
        self.data = {}
        self.batches = {}
        self.pruning_set = set()
        self.is_finalized = False









class SampleBatchData(SizeBatchData):
    """
    SizeBatchData whose fields additionally have names.

    For every named field, an attribute ``field_<name>`` stores the index of
    that field. Use field_XYZ to tell the network in which field to find which
    information. Annotate the fields for the current, goal, other state with
    the format in which they are given. Action shall be a string of the
    grounded action name, heuristic shall be an integer.
    """
    def __init__(self, nb_fields, descriptions=None, field_names=None,
                 file=None, meta=None, prune=False, hasher=hasher.identity,
                 fprune=None, fields_from_first_entry=False):
        """
        Attention: if adding new attributes, refactor
        self._copy_fill_fields(...)

        :param nb_fields: if None, use fields_from_first_entry to set nb_fields
                          to number of fields of first entry
        :param descriptions: field descriptions (see SizeBatchData)
        :param field_names: ordered names of the fields (ignored if
                            fields_from_first_entry is True). None names are
                            allowed and get no field_<name> attribute.
        :param file: stored as meta entry "file"
        :param meta: see SizeBatchData
        :param prune: see SizeBatchData
        :param hasher: see SizeBatchData
        :param fprune: see SizeBatchData
        :param fields_from_first_entry: if True, the field names are taken from
                                        the ``fields`` argument of the first
                                        add(...) call
        """
        SizeBatchData.__init__(self, nb_fields, descriptions,
                               meta=meta, prune=prune, hasher=hasher,
                               fprune=fprune)

        self.fields_from_first_entry = fields_from_first_entry
        if self.fields_from_first_entry:
            self.fields = None
        else:
            self.fields = []
            for field_name in ([] if field_names is None else field_names):
                self._add_field(field_name)

        self.set_meta("file", file)

    def _set_field(self, idx, name):
        """Expose the index of field ``name`` as attribute field_<name>."""
        setattr(self, "field_" + str(name), idx)

    def _add_field(self, name):
        """Append a field name (None allowed) and register its index."""
        if name is not None:
            if name in self.fields:
                raise ValueError("Cannot add a new field which has the same name as"
                                 " an existing field: " + str(name))
            self._set_field(len(self.fields), name)
        self.fields.append(name)

    def _copy_fill_fields(self, other, only_structure, entry_copier):
        """Copy the structure (and, unless only_structure, the data) into other.

        In addition to SizeBatchData._copy_fill_fields(...), this copies the
        field names, the field_<name> index attributes and
        fields_from_first_entry.
        """
        SizeBatchData._copy_fill_fields(self, other,
                                        only_structure=only_structure,
                                        entry_copier=entry_copier)
        if self.fields is None:
            other.fields = None
        else:
            other.fields = []
            for field_name in self.fields:
                other._add_field(field_name)
        other.fields_from_first_entry = self.fields_from_first_entry

    def copy(self, only_structure=False, entry_copier=None):
        """
        Return a copy of this data object. This is NOT a deep copy. For example
        changing the field_descriptions or the meta in the original
        data set also changes this value in the copied data.
        (Internal variables are deep copied, as we know how to do this.)

        :param only_structure: Return instead a new SampleBatchData object,
                               which has the same structure (e.g. number of
                               fields, field names, field descriptions, meta,
                               and hasher) as this object, but does not contain
                               data entries.
        :param entry_copier: If given, the entries from data are not simply
                             added to this object, but entry_copier(entry) is
                             added to this object. Use case: create independent
                             entries which are not affected by changing them in
                             one of the two SampleBatchData objects
        :return: the new SampleBatchData object
        """
        other = SampleBatchData(self.nb_fields)
        SampleBatchData._copy_fill_fields(self, other,
                                          only_structure=only_structure,
                                          entry_copier=entry_copier)
        return other

    def _check_same_fields(self, data):
        """Additionally require identical field names."""
        SizeBatchData._check_same_fields(self, data)
        if self.fields != data.fields:
            raise ValueError("The given SampleBatchData object has different"
                             " fields as the SampleBatchData object to which it"
                             " shall be added")

    def get_file(self):
        """Return the meta entry "file", or None if it is not set."""
        if self.has_meta("file"):
            return self.get_meta("file")
        return None

    def add(self, entry, type=None, fields=None):
        """
        Add an entry, reordering it to this object's field order if needed.

        :param entry: entry to add (in general a list of [field1, field2, ...])
        :param type: category to which to add the entry
        :param fields: ordered list of field names of the entry (fields of this
                       object missing there are filled with None, unknown
                       fields are ignored). Required for the first entry if
                       fields_from_first_entry was set.
        :return: True if the entry was added, False if it was pruned
        :raises ValueError: if field names were expected but not given
        """
        if self.fields_from_first_entry:
            if fields is None:
                raise ValueError("SampleBatchData expected field names to be "
                                 "given with the first data entry, but failed.")
            self.fields = []
            for field_name in fields:
                self._add_field(field_name)
            if self.nb_fields is None:
                self.init_nb_fields(len(fields))

            self.fields_from_first_entry = False

        if fields is not None and len(self.fields) > 0:
            # Reorder only if the names disagree on their common prefix.
            # Otherwise SizeBatchData.add pads or truncates the entry.
            mismatch = any(given != known
                           for given, known in zip(fields, self.fields))
            if mismatch:
                index_of = {name: idx for idx, name in enumerate(fields)}
                entry = [entry[index_of[name]] if name in index_of else None
                         for name in self.fields]
        return SizeBatchData.add(self, entry, type)
