#!/usr/bin/env python

import logging
import os
import tempfile
from shutil import copyfile
from typing import List, Optional, Tuple, Type

from strips_hgn.hypergraph.delete_relaxation import DeleteRelaxationHypergraphView

from strips_hgn.features import (
    GlobalFeatureMapper,
    HyperedgeFeatureMapper,
    NodeFeatureMapper,
)

from torch.utils.data import DataLoader, IterableDataset

from strips_hgn.torch_utils.dataloaders import _collate_hypergraphs_tuples

from strips_hgn.models.strips_hgn import STRIPSHGN
from strips_hgn.utils.args import (
    IterativeTrainingArgs,
    parse_and_validate_training_args,
)
from strips_hgn.utils.helpers import Namespace
from strips_hgn.utils.metrics import CountMetric, metrics_logger
from strips_hgn.utils.timer import TimedOperation, timed
from strips_hgn.utils.wrapper import wrap_method
from strips_hgn.workflows import (
    KFoldTrainingDataWorkflow,
    TrainIterativeSTRIPSHGNWorkflow,
    BaseTrainingDataWorkflow,
)

from strips_hgn.training_data import TrainingPair

from strips_hgn.training_data.generate import (
    _generate_optimal_state_value_pairs_for_problem,
)

from generators.generator import Generator

# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# "results" directory that sits next to the directory containing this file,
# i.e. <parent of this file's directory>/results. The file path is resolved
# with realpath first, so symlinks are followed before going up.
_RESULTS_DIRECTORY = os.path.join(
    os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "results"
)

# File name under which iterative_train stores its copy of the checkpoint with
# the best validation loss (inside the training workflow's checkpoint dir).
_BEST_MODEL_FNAME = "model-best.ckpt"


class WorkflowDummy(BaseTrainingDataWorkflow):
    """
    Placeholder workflow that only exists to expose the helper methods
    inherited from BaseTrainingDataWorkflow.

    It is constructed with an empty list as the base class's first positional
    argument, and `run` does nothing.
    """

    def __init__(
        self,
        global_feature_mapper_cls: Type[GlobalFeatureMapper],
        node_feature_mapper_cls: Type[NodeFeatureMapper],
        hyperedge_feature_mapper_cls: Type[HyperedgeFeatureMapper],
        max_receivers,
        max_senders,
        experiment_dir: str,
    ):
        """
        Forward the feature mapper classes and experiment directory to the
        base workflow, then store the sender/receiver limits on the instance.
        """
        feature_mappers = {
            "global_feature_mapper_cls": global_feature_mapper_cls,
            "node_feature_mapper_cls": node_feature_mapper_cls,
            "hyperedge_feature_mapper_cls": hyperedge_feature_mapper_cls,
        }
        super().__init__([], **feature_mappers, experiment_dir=experiment_dir)

        # The limits are assigned by hand after the base initialiser has run
        self.max_receivers, self.max_senders = max_receivers, max_senders

    def run(self):
        """Intentionally a no-op; this workflow is never executed."""
        return None


class IterableTrainingData(IterableDataset):
    """
    Iterable dataset that produces hypergraph training pairs on demand.

    Whenever the internal buffer of training pairs is empty, a new problem is
    sampled from the generator, solved, and every resulting state-value pair
    is converted with the workflow and buffered. Iteration never raises
    StopIteration by itself.
    """

    def __init__(self, base_workflow, generator):
        """
        :param base_workflow: object providing
            `_create_input_and_target_hypergraphs_tuple`, used to convert
            training pairs.
        :param generator: problem generator; `is_initialized()` must already
            be true and `sample_problem(path)` is called to get new problems.
        """
        assert generator.is_initialized()
        self.generator = generator
        self.workflow = base_workflow
        # Buffer of converted training pairs, consumed from the end
        self.training_pairs = []

    def generate_training_pairs(self):
        """
        Sample one problem from the generator and append the converted
        training pairs of that problem to `self.training_pairs`.

        Nothing is returned. If the sampled problem has no goals, the buffer
        is left untouched.
        """
        # The temporary file is closed (and thereby deleted) deterministically
        # when the block exits, including on exceptions. It is deliberately
        # kept alive for the whole method rather than only for sampling:
        # presumably the problem may still refer to this path while it is
        # being solved -- TODO confirm and narrow the scope if not.
        with tempfile.NamedTemporaryFile() as tf:
            # Sample a new problem (STRIPSProblem)
            problem = self.generator.sample_problem(tf.name)
            # Log whatever ended up in the file; the file is opened in the
            # default binary mode, so the lines are bytes objects.
            for line in tf:
                _log.debug(line)

            # ignore problems where the goal is empty
            if len(problem.goals) == 0:
                return

            state_value_pairs = _generate_optimal_state_value_pairs_for_problem(
                problem
            )

            for pair in state_value_pairs:
                training_pair = TrainingPair(problem, pair)

                # Convert to HypergraphsTupleTrainingPair
                # NOTE(review): the hypergraph view is rebuilt for every pair
                # although it looks loop-invariant -- consider hoisting once
                # it is confirmed that the view is not mutated.
                hg_tuple = self.workflow._create_input_and_target_hypergraphs_tuple(
                    training_pair,
                    DeleteRelaxationHypergraphView(training_pair.problem),
                )
                self.training_pairs.append(hg_tuple)

    def __iter__(self):
        """The dataset is its own iterator."""
        return self

    def __next__(self):
        """
        Return the most recently buffered training pair, generating new ones
        first if the buffer is empty.
        """
        # Keep sampling until at least one pair is available; a sampled
        # problem may contribute no pairs at all (e.g. empty goal).
        while not self.training_pairs:
            self.generate_training_pairs()
        # Remove the last training pair in the list of training pairs
        return self.training_pairs.pop()


def iterative_train(args: IterativeTrainingArgs, experiments_dir):
    """
    Train one STRIPS-HGN model on training data that is sampled on the fly
    from `args.generator`.

    The generator's parameter space is initialised, a dataloader over
    IterableTrainingData is built and used for both training and validation,
    and the training workflow is run. Afterwards the checkpoint with the best
    validation loss is copied to `model-best.ckpt` in the workflow's
    checkpoint directory and the number of trained epochs is recorded as a
    metric.

    :param args: iterative training arguments (generator, timeouts, feature
        mapper classes and optimisation hyperparameters are read from it).
    :param experiments_dir: directory handed to the workflows for their
        outputs.
    """
    timer = TimedOperation("TrainingTime").start()
    fold = 0

    problem_generator = args.generator
    problem_generator.initialize_parameter_space(
        args.fd_timeout, args.initialize_time
    )

    # Per the original author's note, the generator derives these limits from
    # the hardest problem it is able to generate.
    receiver_limit = problem_generator.max_receivers()
    sender_limit = problem_generator.max_senders()

    # The same three mapper classes go into the conversion workflow and into
    # the model hyperparameters.
    mapper_classes = dict(
        global_feature_mapper_cls=args.global_feature_mapper_cls,
        node_feature_mapper_cls=args.node_feature_mapper_cls,
        hyperedge_feature_mapper_cls=args.hyperedge_feature_mapper_cls,
    )

    # Only needed to turn states into hypergraphs
    conversion_workflow = WorkflowDummy(
        max_receivers=receiver_limit,
        max_senders=sender_limit,
        experiment_dir=experiments_dir,
        **mapper_classes,
    )

    # Torch DataLoader over the on-demand dataset. The batch size is fixed
    # here rather than taken from the arguments; the original note reads: we
    # handle batch size manually, because data points in batches are assumed
    # to be individual, which is not the case with our training data.
    loader = DataLoader(
        dataset=IterableTrainingData(conversion_workflow, problem_generator),
        collate_fn=_collate_hypergraphs_tuples,
        batch_size=5,
    )

    # Hyperparameters for STRIPS-HGN
    hparams = Namespace(
        receiver_k=receiver_limit,
        sender_k=sender_limit,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        **mapper_classes,
    )
    model = STRIPSHGN(hparams=hparams)

    # Training and validation deliberately share one dataloader
    train_wf = TrainIterativeSTRIPSHGNWorkflow(
        strips_hgn=model,
        max_training_time=args.max_training_time,
        max_num_epochs=args.max_epochs,
        train_dataloader=loader,
        val_dataloader=loader,
        experiments_dir=experiments_dir,
        prefix=f"fold_{fold}",
        early_stopping_patience=args.patience,
    )
    # TODO adapt training for loss decrease
    train_wf.run()

    # Stopping the timer is what stores it as a metric
    timer.stop()

    # Keep a copy of the best checkpoint under a fixed file name
    best_model_path = os.path.join(train_wf.checkpoint_dir, _BEST_MODEL_FNAME)
    copyfile(train_wf.best_val_loss_checkpoint, best_model_path)

    # Record how many epochs were trained
    epochs_metric = CountMetric(
        "NumberOfEpochsTrained",
        train_wf.current_epoch + 1,
        context={"fold_idx": fold},
    )
    metrics_logger.add_metric(epochs_metric)


def iterative_train_wrapper(args: IterativeTrainingArgs):
    """
    Run `iterative_train` through `wrap_method`, labelled as an
    "iterative_train" experiment and pointed at the module's results
    directory.
    """
    wrapper_options = {
        "args": args,
        "wrapped_method": iterative_train,
        "experiment_type": "iterative_train",
        "results_directory": _RESULTS_DIRECTORY,
    }
    wrap_method(**wrapper_options)


if __name__ == "__main__":
    # Script entry point. The previous call targeted `train_wrapper`, which is
    # neither defined nor imported in this module and therefore raised a
    # NameError; the wrapper defined here is `iterative_train_wrapper`.
    # NOTE(review): this assumes parse_and_validate_training_args() yields
    # arguments compatible with IterativeTrainingArgs -- confirm.
    iterative_train_wrapper(args=parse_and_validate_training_args())
