import os
from typing import List, NamedTuple

from strips_hgn.features.global_features import (
    EmptyGlobalFeatureMapper,
    NumberOfNodesAndEdgesGlobalFeatureMapper,
)
from strips_hgn.features.hyperedge_features import ComplexHyperedgeFeatureMapper
from strips_hgn.features.node_features import PropositionInStateAndGoal
from strips_hgn.utils import Number
from strips_hgn.utils.args import IterativeTrainingArgs

from generators.generator import SingleParameterGenerator

# Baseline keyword arguments that get_training_args() forwards to
# IterativeTrainingArgs. get_training_args() works on a copy of this dict, so
# per-call overrides never modify these module-level defaults.
DEFAULT_ARGS = {
    # Feature Mappers
    # Classes (not instances) are stored here.
    # NOTE(review): presumably instantiated downstream - confirm.
    "global_feature_mapper_cls": EmptyGlobalFeatureMapper,
    "node_feature_mapper_cls": PropositionInStateAndGoal,
    "hyperedge_feature_mapper_cls": ComplexHyperedgeFeatureMapper,
    # Training Hyperparameters
    "batch_size": 1,
    "learning_rate": 0.001,
    "weight_decay": 2.5e-4,
    "max_epochs": 9999,
    # TODO: play around with this parameter - bumped to very high for now
    # NOTE(review): it is unclear whether the TODO above refers to "patience"
    # or to "max_epochs" - confirm and move it next to the intended key.
    "patience": 20,
    # Others
    "debug": True,
    # NOTE(review): both values below are presumably durations in seconds -
    # the unit is not stated anywhere in this file; confirm against
    # IterativeTrainingArgs.
    "fd_timeout": 1 * 10,
    "initialize_time": 1 * 20,
}


class GeneratorConfiguration(NamedTuple):
    """Immutable settings describing a problem generator for one domain.

    All six fields are required and are given positionally or by keyword in
    the order declared below. Two derived values, ``domain`` and
    ``generator_call``, are exposed as read-only properties.
    """

    # Directory that holds the domain file; also the prefix of
    # ``generator_call``.
    base_directory: str
    # File name of the domain definition, resolved against ``base_directory``
    # by the ``domain`` property.
    domain_pddl: str
    # Text appended verbatim to ``base_directory`` to form ``generator_call``.
    call_string: str
    # Bounds on the problem size; get_training_args() passes them unchanged
    # to SingleParameterGenerator.
    min_problem_size: int
    max_problem_size: int
    # Kind of generator; get_training_args() only accepts "single_parameter".
    generator_type: str

    # NOTE: do not define properties or methods that share a name with a
    # field. typing.NamedTuple treats such a class-body name as the field's
    # *default value* and then discards it, so the field would silently
    # default to a ``property`` object instead of being required.

    @property
    def domain(self) -> str:
        """Return ``domain_pddl`` joined onto ``base_directory`` as a path."""
        return os.path.join(self.base_directory, self.domain_pddl)

    @property
    def generator_call(self) -> str:
        """Return ``base_directory`` immediately followed by ``call_string``."""
        # Plain string concatenation: no path separator is inserted, so
        # either base_directory must end with one or call_string must start
        # with one.
        return self.base_directory + self.call_string


def get_training_args(
    configuration: GeneratorConfiguration, max_training_time: Number, **override
) -> IterativeTrainingArgs:
    """Build the IterativeTrainingArgs for training on generated problems.

    Args:
        configuration: Generator settings. Its ``generator_type`` must be
            ``"single_parameter"``.
        max_training_time: Forwarded unchanged as ``max_training_time``.
        **override: Keyword arguments that replace or extend the entries of
            ``DEFAULT_ARGS`` for this call only. Keys that this function
            already passes explicitly (``domain``, ``domains``, ``problems``,
            ``max_training_time``, ``generator``) must not appear here, as
            Python would reject the duplicate keyword with a ``TypeError``.

    Returns:
        An ``IterativeTrainingArgs`` built from the configuration's domain, a
        ``SingleParameterGenerator`` and the merged training arguments.

    Raises:
        AssertionError: If ``configuration.generator_type`` is not
            ``"single_parameter"``.
    """
    # Raised explicitly instead of using an ``assert`` statement so the check
    # still runs under ``python -O``. The exception type and message are kept
    # as they were so existing callers observe no difference.
    if configuration.generator_type != "single_parameter":
        raise AssertionError("Unknown generator type")

    generator = SingleParameterGenerator(
        configuration.domain,
        configuration.generator_call,
        configuration.min_problem_size,
        configuration.max_problem_size,
    )

    # Later mappings win, so overrides take precedence over the defaults.
    # A new dict is built, leaving DEFAULT_ARGS itself untouched.
    training_kwargs = {**DEFAULT_ARGS, **override}

    return IterativeTrainingArgs(
        domain=configuration.domain,
        domains=None,
        problems=None,
        max_training_time=max_training_time,
        generator=generator,
        **training_kwargs,
    )
