# -*- coding: utf-8 -*-
"""
Created on Sun Aug 18 16:39:56 2024

@author: David
"""
from tqdm import tqdm
import numpy as np
import pathlib
import os
from collections import defaultdict, Counter
import tensorflow as tf
from tensorflow.io import TFRecordWriter, serialize_tensor, parse_tensor
from contextlib import ExitStack
import itertools
from keras.preprocessing.image import ImageDataGenerator
from tensorflow import keras
import gc
import math
import random
from numpy.random import Generator, PCG64



def crear_patches_generator(cobertura, auxiliares_normalizados, patch_size=5):
    """
    Generador que devuelve individualmente cada patch calculado en lugar de 
    cargalos enteros en memoria.
    
    Argumentos
    ----------
    cobertura: numpy.ndarray
    auxiliares: numpy.ndarray
    patch_size: int
    """
    half_patch = patch_size // 2

    # Configurar las barras de progreso
    total_patches = (cobertura.shape[0] - patch_size + 1) * (cobertura.shape[1] - patch_size + 1)
    progress_bar = tqdm(total=total_patches, desc="Creando parches")

    for i in range(half_patch, cobertura.shape[0] - half_patch):
        for j in range(half_patch, cobertura.shape[1] - half_patch):
            # Crear el parche de la cobertura terrestre
            patch_cobertura = cobertura[i - half_patch: i + half_patch + 1,
                                        j - half_patch: j + half_patch + 1]
            if np.isnan(patch_cobertura).any():
                progress_bar.update(1)
                continue

            # Crear los parches de los auxiliares
            patch_aux = []
            for k in range(auxiliares_normalizados.shape[-1]):
                aux_patch = auxiliares_normalizados[i - half_patch: i + half_patch + 1,
                                       j - half_patch: j + half_patch + 1, k]
                if np.isnan(aux_patch).any():
                    patch_aux = None
                    break
                patch_aux.append(aux_patch)
            
            if patch_aux is not None:
                yield patch_cobertura, np.stack(patch_aux, axis=-1), cobertura[i, j]
            progress_bar.update(1)
    
    progress_bar.close()
    
    
def organizar_secuencias_temporales_generator(patches, timesteps=3,
                                              out_path = "sequences"):
    """
    Guarda en fichero las secuencias temporales obtenidas a partir de los parches.

    Parameters
    ----------
    patches : np.ndarray
        Parches de las imágenes, ordenados temporalemente
    timesteps : int, optional
        Número de parches de fechas contiguas que agrupar en la secuencia.
    out_path : str, optional
        Directorio donde guardar los secuencias producidas

    Returns
    -------
    files : List[str]
        Nombre de los ficheros donde se guardan las secuencias
    class_counts : Dict[Number, int]
        Diccionario cuyas claves son las clases y los valores su frecuencia
    total_written : int
        Número total de elementos (secuencias) escritos

    """
    pathlib.Path(out_path).mkdir(exist_ok = True, parents = True)
    
    files = [os.path.join(out_path, "sequences_{}.tfrecords".format(i)) for i in range(len(patches) - timesteps)]
    class_counts = defaultdict(float)
    total_written = 0
    writers = [TFRecordWriter(file_path) for file_path in files]
    with ExitStack() as estack:
        for w in writers:
            estack.enter_context(w)
        while True:
            try:
                seq = [next(patches[t]) for t in range(len(patches))]
                X_seq = [elem[0] for elem in seq]
                aux_seq = [elem[1] for elem in seq]
                y_seq = [elem[2] for elem in seq]
                for j in range(len(patches) - timesteps):
                    total_written += 1
                    X_seq_j = np.array([X_seq[j + t] for t in range(timesteps)])
                    aux_seq_j = np.array([aux_seq[j + t] for t in range(timesteps)])
                    y_seq_j = np.array(y_seq[j + timesteps])
                    class_counts[y_seq[j + timesteps]] += 1
                    escribir_ejemplo_stacked(writers[j], X_seq_j.reshape(X_seq_j.shape + (1,)), aux_seq_j, y_seq_j)
            except StopIteration:
                break
    return files, class_counts, total_written



def escribir_ejemplo_stacked(tf_writer, x, aux, y):
    """
    Escribe un solo ejemplo con un writer de tensorflow

    Parameters
    ----------
    tf_writer : tensorflow.io.TFRecordWriter
    x : numpy.ndarray
        Dato (secuencia)
    aux: numpy.ndarray
        Dato auxiliar (secuencia)ç
    y: Number
        Etiqueta

    Returns
    -------
    None.

    """
    # Se guardan los datos
    record_bytes = tf.train.Example(
        features=tf.train.Features(
            feature={
            "x": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[serialize_tensor(x).numpy()])
            ),
            "aux": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[serialize_tensor(aux).numpy()])
            ),
            "y": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[serialize_tensor(y).numpy()])
            )
            }
        )
    ).SerializeToString()
    tf_writer.write(record_bytes)


def batched(iterable, n):
    """
    Devuelve secuencias de longitud n obtenidas de iterable
    
    Argumentos
    ----------
    iterable: Iterable
        El iterable del que obtener las secuencias
    n: int
        La longitud de las secuencias
    """
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(itertools.islice(iterator, n)):
        yield batch

def mix_patches(patch_list, batch_size = 32):
    """
    Construye secuencias temporales de x, aux e y a partir de una lista
    cuyos elementos son tuplas con el formato (x, aux, y)
    """
    for combined_batch in batched(zip(*patch_list), batch_size):
        combined_batch = list(combined_batch)
        yield (
            (np.array([np.expand_dims([elem[0] for elem in combined], axis = -1) for combined in combined_batch]),
            np.array([[elem[1] for elem in combined] for combined in combined_batch])),
            np.array([[elem[2] for elem in combined][-1] for combined in combined_batch])
        )


def load_file_stacked(file_path):
    """
    Carga ejemplos de un archivo en forma de  tf.data.Dataset.

    Parameters
    ----------
    file_path : str
        La ruta del archivo.

    Returns
    -------
    tf.data.Dataset
        Los datos del fichero

    """
    def decode_example(record_bytes):
      return tf.io.parse_single_example(
              # Datos
              record_bytes,
    
              # Schema
              {"x": tf.io.FixedLenFeature([], dtype=tf.string, default_value = ""),
               "aux": tf.io.FixedLenFeature([], dtype=tf.string, default_value = ""),
               "y": tf.io.FixedLenFeature([], dtype=tf.string, default_value = "")}
          )

    def return_decoded_tuple(dict_vals):
        # Se deshace la transformación a array monodimensional para obtener
        # los datos auxiliares, originales y la 
        
        x = parse_tensor(dict_vals["x"], out_type = tf.float32)
        aux = parse_tensor(dict_vals["aux"], out_type = tf.float32)
        y = parse_tensor(dict_vals["y"], out_type = tf.float32)
        
        return x, aux, y
    
    # Es necesario aplicar las transformaciones al conjunto de datos,
    # para que devuelva los datos procesados
    return tf.data.TFRecordDataset([file_path]).map(decode_example).map(return_decoded_tuple)


def augmentar_secuencia(secuencia, seed = 42):
    datagen = ImageDataGenerator(rotation_range=10, width_shift_range=0.1, height_shift_range=0.1)
    secuencia_aumentada = []
    for i in range(secuencia.shape[0]):
        imagen = secuencia[i, :, :, :].numpy()  # Extraer cada imagen en la secuencia
        imagen = imagen.reshape((1,) + imagen.shape)  # Ajustar la forma para el generador
        #iterador = datagen.flow(imagen, batch_size=1, seed = seed)
        iterador = datagen.flow(imagen, batch_size=1)
        secuencia_aumentada.append(iterador[0].reshape(imagen.shape[1:]))  # Aplicar augmentación
    return np.array(secuencia_aumentada)


def ajustar_division_clases_proporcion_stream_full_random(files, total_size, test_proportion=0.2,
                                              min_samples_per_class=5, 
                                              out_dir = "ordered_sequences",
                                              target_size_undersample = 1000000,
                                              factor_oversample = 10,
                                              seed = 42,
                                              num_shards = 100): 
    """
    Guarda los datos de entrenamiento y testeo, aplicando la corrección del 
    número de elementos de test y el oversampling, y distribuyéndolos
    aleatoriamente en diferentes ficheros

    Parameters
    ----------
    files : List[str]
        Ficheros donde se guardaron las secuencias originales
    total_size : int
        Número total de elementos guardados en ficheros
    test_proportion : float, optional
        Proporción de datos de test. The default is 0.2.
    min_samples_per_class : int, optional
        Número mínimo de elementos de cada clase. The default is 5.
    out_dir : str, optional
        Directorio donde se guardan los datos de entrenamiento y test.
        The default is "ordered_sequences".
    target_size_undersample : int, optional
        Número máximo de elementos de la clase mayoritaria.
        The default is 1000000.
    factor_oversample : int, optional
        Factor de aumento de elementos de las clases minoritarias.
        The default is 10.
    seed : int, optional
        Semilla para el generador de número pseudoaleatorio. The default is 42.
    num_shards : int, optional
        Número de archivos en los que guardar los datos de entrenamiento.
        The default is 100.

    Returns
    -------
    train_paths : List[str]
        Lista de archivos en los que se guardan los datos de entrenamiento
    test_paths : List[str]
        Número de archivos en los que se guardan los datos de test
    class_counter_train : Dict[Number, int]
        Número de elementos de cada clase en datos de entrenamiento
    class_counter_test : Dict[Number, int]
        Número de elementos de cada clase en datos de test

    """
    pathlib.Path(out_dir).mkdir(parents = True, exist_ok = True)
    data = itertools.chain(*[load_file_stacked(f) for f in files])
    train_size = int(total_size * (1 - test_proportion))
    class_counter_train = Counter()
    class_counter_test = Counter()
    class_counter_train.update((next(data)[-1].numpy() for _ in range(train_size)))
    class_counter_test.update((x[-1].numpy() for x in data))
    
    majority_class = max(class_counter_train, key = lambda x: class_counter_train[x])
    
    class_move_dict = dict()
    for clase in class_counter_train.keys() | class_counter_test.keys():
        num_to_move = max(int(class_counter_train[clase] * test_proportion), min_samples_per_class)
        class_move_dict[clase] = min(num_to_move, class_counter_train[clase] - 1)
    
    data = itertools.chain(*[load_file_stacked(f) for f in files])
    
    files_out = [os.path.join(out_dir, "train_shard_{}.tfrecords".format(i)) for i in range(num_shards)]
    writers = [TFRecordWriter(path) for path in files_out]
    
    test_path = os.path.join(out_dir, "test.tfrecords")
    moved_paths = {klass: os.path.join(out_dir, "moved_{}.tfrecords".format(klass))
                   for klass in class_move_dict}
    test_writer = TFRecordWriter(test_path)
    moved_writer = {klass: TFRecordWriter(path) for klass, path in moved_paths.items()}
    # Número de elementos de cada clase vistos
    class_seen = {klass: 0 for klass in class_move_dict}
    majority_class = max(class_counter_train, key = lambda x: class_counter_train[x])
    
    class_counter_train = Counter()
    class_counter_test = Counter()
    
    rng = Generator(PCG64(seed))
    progress_bar = tqdm(total=total_size, desc="Dividiendo train y test")
    idxs = list(range(num_shards))
    with ExitStack() as e:
        e.enter_context(test_writer)
        for _, tfw in moved_writer.items():
            e.enter_context(tfw)
        for tfw in writers:
            e.enter_context(tfw)

        for i, (x, aux, y) in enumerate(data):
            y = y.numpy()
            if i < train_size: # Elemento pertenece al conjunto de entrenamiento
                # Se guardan los elementos a mover aparte
                if class_seen[y] < class_move_dict[y]:
                    escribir_ejemplo_stacked(moved_writer[y], x, aux, y)
                    class_counter_test.update([y])
                else:
                    # Undersampling: se ignoran las primeras
                    if y == majority_class:
                        if class_seen[y] > class_counter_train[y] - target_size_undersample:
                            writer_idx = random.choice(idxs)
                            escribir_ejemplo_stacked(writers[writer_idx], x, aux, y)
                            class_counter_train.update([y])
                    # Oversampling: se añaden las nuevas muestras aparte
                    else:
                        # se escribe normalmente
                        writer_idx = random.choice(idxs)
                        escribir_ejemplo_stacked(writers[writer_idx], x, aux, y)
                        class_counter_train.update([y])
                        # Se aplica oversampling
                        for over_i in range(factor_oversample):
                            writer_idx = random.choice(idxs)
                            ss = rng.integers(2**32 - 1)
                            ss = 42
                            over_x = augmentar_secuencia(x, seed = ss)
                            over_aux = augmentar_secuencia(aux, seed= ss)
                            escribir_ejemplo_stacked(writers[writer_idx], over_x, over_aux, y)
                            class_counter_train.update([y])
                        
            else:  # El elemento es del conjunto de test
                escribir_ejemplo_stacked(test_writer, x, aux, y)
                class_counter_test.update([y])
            class_seen[y] += 1
            progress_bar.update(1)
    progress_bar.close()
    

    train_paths = files_out
    sorted_moved_paths = list(
                itertools.starmap(
                    lambda x, y: y, sorted(moved_paths.items(), key = lambda x: x[0])
                )
    )
    test_paths = [test_path] + sorted_moved_paths
    return train_paths, test_paths, class_counter_train, class_counter_test



def first_seq_stream_full_random(files, target_size, out_dir = "ordered_sequences",
                                num_shards = 50):
    """
    Guarda los datos de la primera secuencia en orden aleatorio y los distribuye
    en varios archivos

    Parameters
    ----------
    files : List[str]
        Ficheros donde se guardaron las secuencias originales
    target_size : int
        Número total de elementos guardados en ficheros
    out_dir : str, optional
        Directorio donde se guardan los datos.
        The default is "ordered_sequences".
    num_shards : int, optional
        Número de archivos en los que guardar los datos.
        The default is 100.

    Returns
    -------
    files_out : List[str]
        Lista de archivos en los que se guardan los datos de entrenamiento
    class_counter_train : Dict[Number, int]
        Número de elementos de cada clase en datos
    """
    pathlib.Path(out_dir).mkdir(parents = True, exist_ok = True)
    data = itertools.chain(*[load_file_stacked(f) for f in files])
    class_counter_train = Counter()
    
    files_out = [os.path.join(out_dir, "fs_shard_{}.tfrecords".format(i)) for i in range(num_shards)]
    writers = [TFRecordWriter(path) for path in files_out]

    data = itertools.chain(*[load_file_stacked(f) for f in files])
    
    progress_bar = tqdm(total=target_size, desc="Dividiendo train y test")
    
    idxs = list(range(num_shards))
    with ExitStack() as e:
        for tfw in writers:
            e.enter_context(tfw)

        for i, (x, aux, y) in enumerate(data):
            y = y.numpy()
            if i < target_size: 
                writer_idx = random.choice(idxs)
                escribir_ejemplo_stacked(writers[writer_idx], x, aux, y)
                class_counter_train.update([y])
            else:
                break                        
            progress_bar.update(1)
    progress_bar.close()
    
    return files_out, class_counter_train




from tensorflow.keras import backend as keras_backend
from tensorflow.keras.callbacks import Callback


def collect_trash():
    """
    Función que se encarga de llamar al recogedor de basura y limpiar la
    sesión de keras para evitar fuga de memoria
    """
    gc.collect()
    gc.collect()
    keras_backend.clear_session()
    gc.collect()
    gc.collect()


class ClearMemory(Callback):
    def on_epoch_end(self, epoch, logs=None):
        collect_trash()  
        collect_trash()


def load_dataset_for_train(file_paths, batch_size):
    """
    Función para cargar datos de diferentes ficheros. Para asegurar que se
    minimiza la correlación del orden de los datos, se intercalan los elementos
    de diferentes ficheros a la vez.

    Parameters
    ----------
    file_paths : List[str]
        Lista de ficheros donde están guardados los datos que se desea extraer
    batch_size : int
        Tamaño del batch que devolverá el dataset
    Returns
    -------
    tf.data.Dataset
        El conjunto de datos

    """
    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
    dataset = dataset.flat_map(lambda x: load_file_stacked(x))    

    return dataset.batch(batch_size)


def load_dataset_for_train_randomized(file_paths, batch_size,
                           cycle_length = None, 
                           block_length = 1):
    """
    Función para cargar datos de diferentes ficheros. Para asegurar que se
    minimiza la correlación del orden de los datos, se intercalan los elementos
    de diferentes ficheros a la vez.

    Parameters
    ----------
    file_paths : List[str]
        Lista de ficheros donde están guardados los datos que se desea extraer
    x_shape : tuple[int]
        shape de los datos de entrada
    y_shape : tuple[int]
        shape de las etiquetas
    aux_shape : tuple[int]
        shape de los datos de entrada auxiliar
    batch_size : int
        Tamaño del batch que devolverá el dataset
    clave : Hashable, optional
        Clave usada para guardar cada elemento. The default is "total".
    cycle_length : int, optional
        Número de ficheros simultáneos para los que se intercalan los datos.
        Se extraen elementos del fichero hasta que se agota y se sustituye por
        el siguiente.
        The default is None.
    block_length : int, optional
        Número de elementos extraídos de cada fichero hasta pasar al siguiente.
        The default is 1.

    Returns
    -------
    tf.data.Dataset
        El conjunto de datos

    """
    cycle_length = cycle_length or len(file_paths)
    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
    # En lugar de devolver los datos de un archivo entero y 
    # despues los de otro, se intercalan.
    dataset = dataset.interleave(lambda x:
        load_file_stacked(x),
        cycle_length=cycle_length, block_length=block_length)    
    return dataset.batch(batch_size)

def construct_labels(freq_dict):
    return np.fromiter(
        itertools.chain(
            *[[klass] * freq for klass, freq in freq_dict.items()]    
        )
        , dtype = np.float32
    )

class ImageSequence(keras.utils.Sequence):
    """
    Para pasar el conjunto de datos de entrenamiento y test/validación sin 
    tener que cargarlo en memoria.
    """
    def __init__(self, file_paths, 
                  batch_size, num_instances,
                  limit = None, collect_interval = None):
        super().__init__()
        self.file_paths = list(file_paths)
        self.num_instances = num_instances
        self.batch_size = batch_size
        self.limit  = limit
        self.collect_interval = collect_interval
        self.prepare_dataset()
        
    def prepare_dataset(self):
        # Cambia el orden de los archivos
        dataset = load_dataset_for_train(self.file_paths, self.batch_size)
        self.data = dataset
        self.iterator = iter(dataset)
        self.reserve_iterator = iter(dataset)
        self.count = 0
    
    def check_collect(self):
        if self.collect_interval and self.count % self.collect_interval == 0:
            collect_trash()
            collect_trash()


    def __len__(self):
        if self.limit is not None:
            return math.ceil(self.limit / self.batch_size)
        # Return number of batches.
        return math.ceil(self.num_instances / self.batch_size)

    def __getitem__(self, idx):
        """
        Keras especifica que se debería devolver el batch en la posición idx.
        Como tenemos los ficheros en disco, y los intercalamos para producir un
        batch, devolver el correspondiente al índice requeriría recorrer el
        iterador hasta llegar a idx por cada petición.
        
        Hay alternativas:
            - Guardar los ficheros de modo que cada fichero se corresponda a un
            batch. En este caso simplemente cargaríamos el fichero en el índice
            correspondiente y devolveríamos sus contenidos enteros.
            El problema (que puede no ser para tanto) que le veo es que dificulta
            el shuffling de los datos, ya que para diferentes epochs, cada
            batch tendría el mismo contenido.
            - Hacer el recorrido del iterador para cada petición. Si se fija
            shuffle = False en el método fit de Model, los batches se piden
            en orden, así que se podría reutilizar el iterador y en la práctica
            no se perdería mucho tiempo. Me he abstenido por si acaso se desea 
            aplicar paralelismo: en ese caso los índices se pedirían fuera de
            orden y se podría ralentizar el entrenamiento.
            No obstante, no he medido el tiempo que se tardaría en iterar sobre
            el dataset, es posible que no sea demasiado.
        
        Lo que se hace es devolver el primer batch del iterador de reserva
        (keras los saca para computar tipos de datos y shape) y después iterar
        sobre el conjunto de datos cargado, devolviendo los batches en orden.
        """
        if self.count:
            x, aux, y = next(self.iterator)
        else:
            x, aux, y = next(self.reserve_iterator)
            del self.reserve_iterator
        self.count += 1
        # La documentación menciona que se esperan los datos en formato
        # (datos, etiquetas).
        self.check_collect()
        return ((x, aux), y)

    def on_epoch_end(self):
        # Este método se llamará al final de cada epoch. Lo que hacemos es 
        # modificar el orden de los ficheros a cargar para simular shuffling
        # de los datos
        self.prepare_dataset()


class ImageSequenceRandomized(keras.utils.Sequence):
    """
    Para pasar el conjunto de datos de entrenamiento y test/validación sin 
    tener que cargarlo en memoria. Aleatoriza el orden de los elementos 
    de los archivos pasados en cada epoch
    """
    def __init__(self, file_paths, 
                  batch_size, num_instances,
                  limit = None, seed = 42,
                  block_length = 1,
                  cycle_length = None):
        super().__init__()
        self.file_paths = list(file_paths)
        self.num_instances = num_instances
        self.batch_size = batch_size
        self.limit = limit
        self.rng = Generator(PCG64(seed))
        self.block_length = block_length
        self.cycle_length = cycle_length
        self.prepare_dataset()
        
    def prepare_dataset(self):
        # Cambia el orden de los archivos
        self.rng.shuffle(self.file_paths)
        dataset = load_dataset_for_train_randomized(
            self.file_paths, self.batch_size, self.cycle_length, self.block_length
        )
        self.data = dataset
        self.iterator = iter(dataset)
        self.reserve_iterator = iter(dataset)
        self.count = 0
    
    def __len__(self):
        if self.limit is not None:
            return math.ceil(self.limit / self.batch_size)
        # Return number of batches.
        return math.ceil(self.num_instances / self.batch_size)

    def __getitem__(self, idx):
        """
        Keras especifica que se debería devolver el batch en la posición idx.
        Como tenemos los ficheros en disco, y los intercalamos para producir un
        batch, devolver el correspondiente al índice requeriría recorrer el
        iterador hasta llegar a idx por cada petición.
        
        Hay alternativas:
            - Guardar los ficheros de modo que cada fichero se corresponda a un
            batch. En este caso simplemente cargaríamos el fichero en el índice
            correspondiente y devolveríamos sus contenidos enteros.
            El problema (que puede no ser para tanto) que le veo es que dificulta
            el shuffling de los datos, ya que para diferentes epochs, cada
            batch tendría el mismo contenido.
            - Hacer el recorrido del iterador para cada petición. Si se fija
            shuffle = False en el método fit de Model, los batches se piden
            en orden, así que se podría reutilizar el iterador y en la práctica
            no se perdería mucho tiempo. Me he abstenido por si acaso se desea 
            aplicar paralelismo: en ese caso los índices se pedirían fuera de
            orden y se podría ralentizar el entrenamiento.
            No obstante, no he medido el tiempo que se tardaría en iterar sobre
            el dataset, es posible que no sea demasiado.
        
        Lo que se hace es devolver el primer batch del iterador de reserva
        (keras los saca para computar tipos de datos y shape) y después iterar
        sobre el conjunto de datos cargado, devolviendo los batches en orden.
        """
        if self.count:
            x, aux, y = next(self.iterator)
        else:
            x, aux, y = next(self.reserve_iterator)
            del self.reserve_iterator
        self.count += 1
        # La documentación menciona que se esperan los datos en formato
        # (datos, etiquetas).
        return ((x, aux), y)

    def on_epoch_end(self):
        # Este método se llamará al final de cada epoch. Lo que hacemos es 
        # modificar el orden de los ficheros a cargar para simular shuffling
        # de los datos
        self.prepare_dataset()



if __name__ == "__main__":
    pass