import os
from typing import List, NamedTuple
from os import listdir
from os.path import isfile, join

from strips_hgn.features.global_features import (
    EmptyGlobalFeatureMapper,
    NumberOfNodesAndEdgesGlobalFeatureMapper,
)
from strips_hgn.features.hyperedge_features import ComplexHyperedgeFeatureMapper
from strips_hgn.features.node_features import PropositionInStateAndGoal
from strips_hgn.utils import Number
from strips_hgn.utils.args import JSONTrainingArgs

# Default keyword arguments passed to JSONTrainingArgs. get_training_args()
# copies this dict and applies any caller-supplied overrides on top of it, so
# the values here are only used for keys the caller does not override.
JSON_DEFAULT_ARGS = {
    # Training data generation and k-fold
    "num_bins": 4,
    "num_folds": 10,
    # Feature Mappers (the mapper classes themselves, not instances)
    "global_feature_mapper_cls": EmptyGlobalFeatureMapper,
    "node_feature_mapper_cls": PropositionInStateAndGoal,
    "hyperedge_feature_mapper_cls": ComplexHyperedgeFeatureMapper,
    # Training Hyperparameters
    "batch_size": 1,
    "learning_rate": 0.001,
    "weight_decay": 2.5e-4,
    # NOTE(review): very large value - presumably training is expected to be
    # stopped by max_training_time rather than by the epoch count; confirm.
    "max_epochs": 9999,
    # TODO: play around with this parameter - bumped to very high for now
    "patience": 999,
    # Others
    "debug": True,
    "remove_duplicates": False,
    "shuffle": True,
}


class JSONConfiguration(NamedTuple):
    """
    Locations of the inputs for one JSON-based training configuration.

    The properties below resolve `instance_directory`, `domain_pddl` and
    `json_dir` by joining them onto `base_directory` with `os.path.join`.
    """

    # Directory the other three fields are joined onto
    base_directory: str
    # Directory of problem instances (joined onto base_directory)
    instance_directory: str
    # Domain PDDL file (joined onto base_directory)
    domain_pddl: str
    # Directory containing the json files (joined onto base_directory)
    json_dir: str

    @property
    def domain(self) -> str:
        """ Path to the domain file: `base_directory` joined with `domain_pddl` """
        return os.path.join(self.base_directory, self.domain_pddl)

    @property
    def instance_dir(self) -> str:
        """ Directory of instances which the json data was generated for """
        return os.path.join(self.base_directory, self.instance_directory)

    @property
    def json_files(self) -> List[str]:
        """
        Sorted list of paths to the files directly inside the json directory.

        Only entries for which `isfile` is true are returned, so
        sub-directories are skipped and the listing is not recursive. The
        directory is re-read on every access.

        Raises
        ------
        OSError (e.g. FileNotFoundError) if the json directory cannot be listed.
        """
        directory = os.path.join(self.base_directory, self.json_dir)
        candidates = (join(directory, name) for name in listdir(directory))
        # listdir() yields entries in arbitrary, platform-dependent order;
        # sort so that the resulting file list is reproducible between runs.
        return sorted(path for path in candidates if isfile(path))


def get_training_args(
    configurations: List[JSONConfiguration], max_training_time: Number, **override
) -> JSONTrainingArgs:
    """
    Build the JSONTrainingArgs for a single JSON configuration.

    Parameters
    ----------
    configurations: list which must contain exactly one JSONConfiguration
    max_training_time: passed through unchanged as `max_training_time`
    override: keyword arguments which replace (or add to) the entries of
        JSON_DEFAULT_ARGS

    Returns
    -------
    JSONTrainingArgs built from the configuration's domain, json files and
    instance directory, together with the merged keyword arguments
    """
    assert configurations, "At least one configuration must be provided!"

    # Start from the defaults; anything the caller passed wins on a key clash
    merged_kwargs = {**JSON_DEFAULT_ARGS, **override}

    assert len(configurations) == 1, "Too many training configurations"
    only_configuration = configurations[0]

    return JSONTrainingArgs(
        domain=only_configuration.domain,
        domains=None,
        problems=None,
        json_files=only_configuration.json_files,
        instance_dir=only_configuration.instance_dir,
        max_training_time=max_training_time,
        **merged_kwargs
    )
