import logging
from typing import Dict, List, Optional

from fast_downward_api import get_optimal_actions_using_fd_with_timeout
from training_data import StateValuePair

_log = logging.getLogger(__name__)


def _get_state_value_pairs(
    problem: str, optimal_plan, trajectory
) -> List[StateValuePair]:
    """
    Pairs every state of a trajectory with the plan cost still remaining.

    The i-th state is given the summed cost of the plan steps from index i
    onwards, so the first state receives the total plan cost.

    Parameters
    ----------
    problem: identifier of the problem, only used in log/error messages
    optimal_plan: sequence of (operator, cost) pairs, or None when no plan
        is available
    trajectory: sized sequence of states; trajectory[i] is treated as the
        state in which optimal_plan[i] is applied (inferred from how the two
        are paired here - TODO confirm against the Fast-Downward wrapper)

    Returns
    -------
    List[StateValuePair] with one entry per trajectory state. Empty when
    optimal_plan is None or contains no steps.

    Raises
    ------
    IndexError if the trajectory holds more states than the plan has steps
    """
    # Check some edge cases
    if optimal_plan is None:
        _log.error(f"Unable to find optimal solution for {problem}")
        return []
    if len(optimal_plan) == 0:
        _log.warning(f"Initial state for {problem} is already a goal state!")
        return []

    # Total plan cost, i.e. the value assigned to the first state
    remaining_cost = sum(cost for _, cost in optimal_plan)

    # Every state needs a matching plan step to take its cost from. Fail
    # with a descriptive message instead of an opaque "list index out of
    # range" raised halfway through the loop.
    if len(trajectory) > len(optimal_plan):
        raise IndexError(
            f"Trajectory for {problem} has {len(trajectory)} states but the "
            f"plan only has {len(optimal_plan)} steps"
        )

    state_value_pairs = []
    # zip stops at the shorter input; the check above guarantees that is the
    # trajectory, so each state yields exactly one pair.
    for state, (_, cost) in zip(trajectory, optimal_plan):
        state_value_pairs.append(StateValuePair(state, remaining_cost))
        remaining_cost -= cost

    return state_value_pairs


def _generate_optimal_state_value_pairs_for_problem(
    domain, problem, timeout, sas_name="output.sas", plan_name="sas_plan",
) -> List[StateValuePair]:
    """
    Solves a planning problem optimally and labels the visited states.

    Runs Fast-Downward (with a timeout) to obtain an optimal plan together
    with its trajectory, then converts both into state-value pairs.

    Parameters
    ----------
    domain: the planning domain, forwarded unchanged to the Fast-Downward
        wrapper
    problem: the planning problem to solve; also used to identify the
        problem in log messages
    timeout: time limit forwarded to the Fast-Downward wrapper (units are
        defined by that wrapper - TODO confirm)
    sas_name: forwarded to the Fast-Downward wrapper; presumably the name of
        the intermediate SAS file - verify against the wrapper
    plan_name: forwarded to the Fast-Downward wrapper; presumably the name
        of the plan file - verify against the wrapper

    Returns
    -------
    List[StateValuePair] pairing the trajectory states with their remaining
    plan cost. Empty if no optimal plan was found or the plan has no steps.
    """
    # Ask Fast-Downward for the optimal plan and the states along it
    fd_result = get_optimal_actions_using_fd_with_timeout(
        domain, problem, timeout, sas_name, plan_name
    )
    plan, visited_states = fd_result
    return _get_state_value_pairs(problem, plan, visited_states)
