from __future__ import annotations

from collections import deque
from collections.abc import Iterable
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Optional

import manim

from automata.base.animation import Animate, _ManimEdge, _ManimInput, _ManimNode
from automata.base.exceptions import RejectionException
from automata.fa.fa import FA, FAStateT

if TYPE_CHECKING:
    from automata.fa.dfa import DFA, DFAStateT
    from automata.fa.nfa import NFA, NFAStateT


class _FAGraph:
    """
    Manim picture of an `FA` object, shared by the FA animation scenes.

    The picture is built from the `AGraph` object exposed as `fa.diagram`: each
    node and each edge of that graph is wrapped in a Manim object and all of them
    are collected in one `VDict`.

    Attributes
    ----------
    graph : VDict
        The drawn elements of the `FA` object. States are keyed by the state itself
        (type `FAStateT`) and transitions by the pair `(tail, head)` (type
        `tuple[FAStateT, FAStateT]`). Diagram nodes whose name matches no state are
        keyed by `None`; this is the `nullnode` of `FA.show_diagram`, the start
        point of the arrow pointing to the initial state.
    """

    graph: manim.VDict

    def __init__(self, fa: FA) -> None:
        """
        Build the graph, scale it to the frame width and align it to the left and
        bottom borders of the frame.

        Parameters
        ----------
        fa : FA
            The FA object whose diagram is drawn.

        Note
        ----
        Diagram nodes are mapped back to states through `FA._get_state_name`, so
        two states must not share a name, and no state may be named 'None'.
        """
        state_by_name = {FA._get_state_name(state): state for state in fa.states}
        elements = {}
        for node in fa.diagram.nodes_iter():
            # A node name that matches no state is stored under the key None.
            elements[state_by_name.get(node)] = _ManimNode(node)
        for edge in fa.diagram.edges_iter():
            endpoints = tuple(state_by_name.get(endpoint) for endpoint in edge)
            elements[endpoints] = _ManimEdge(edge)
        self.graph = manim.VDict(elements)
        self.graph.scale_to_fit_width(manim.config.frame_width)
        left_aligned = self.graph.align_on_border(manim.LEFT, buff=0)
        left_aligned.align_on_border(manim.DOWN)

    def highlight_states(
        self, states: Iterable[FAStateT]
    ) -> Iterable[manim.ApplyMethod]:
        """
        Highlight the given states; entries that are `None` are skipped.

        Parameters
        ----------
        states : Iterable[FAStateT]
            The states to highlight.

        Returns
        -------
        Iterable[ApplyMethod]
            The animations for the `Scene` object to `play`.
        """
        for state in states:
            if state is None:
                continue
            yield Animate.highlight(self.graph[state])

    def change_states(
        self, old_states: Iterable[FAStateT], new_states: Iterable[FAStateT]
    ) -> Iterable[manim.ApplyMethod]:
        """
        Turn `old_states` to default color, then highlight `new_states`.

        The highlight animations are always produced after the default-color
        ones, so a state occurring in both `old_states` and `new_states` gets
        both animations, the highlight one last.

        Parameters
        ----------
        old_states : Iterable[FAStateT]
            The states to turn to default color.
        new_states : Iterable[FAStateT]
            The states to highlight.

        Returns
        -------
        Iterable[ApplyMethod]
            The animations for the `Scene` object to `play`.
        """
        for state in old_states:
            yield Animate.to_default_color(self.graph[state])
        yield from self.highlight_states(new_states)

    def change_transitions(
        self,
        old_transitions: Iterable[tuple[FAStateT, FAStateT]],
        new_transitions: Iterable[tuple[FAStateT, FAStateT]],
    ) -> Iterable[manim.ApplyMethod]:
        """
        Turn `old_transitions` to default color, then highlight `new_transitions`.

        The highlight animations are always produced after the default-color
        ones, so a transition occurring in both arguments gets both animations,
        the highlight one last. New transitions whose last element is `None` are
        not highlighted.

        Parameters
        ----------
        old_transitions : Iterable[tuple[FAStateT, FAStateT]]
            The transitions to turn to default color.
        new_transitions : Iterable[tuple[FAStateT, FAStateT]]
            The transitions to highlight.

        Returns
        -------
        Iterable[ApplyMethod]
            The animations for the `Scene` object to `play`.
        """
        for transition in old_transitions:
            yield Animate.to_default_color(self.graph[transition])
        for transition in new_transitions:
            if transition[-1] is None:
                continue
            yield Animate.highlight(self.graph[transition])

    def clean(
        self, transitions: Iterable[tuple[FAStateT, FAStateT]]
    ) -> Iterable[manim.ApplyMethod]:
        """
        Turn the given transitions back to default color.

        Transitions whose head (second element) is `None` are skipped.

        Parameters
        ----------
        transitions : Iterable[tuple[FAStateT, FAStateT]]
            The highlighted transitions whose highlight is cancelled.

        Returns
        -------
        Iterable[ApplyMethod]
            The animations for the `Scene` object to `play`.
        """
        for transition in transitions:
            if transition[1] is None:
                continue
            yield Animate.to_default_color(self.graph[transition])


class _DFAAnimation(manim.Scene):
    """
    Scene that animates a DFA identifying an input string.

    To generate the animation, use `_DFAAnimation.render`, which will call the
    `setup` method and the `construct` method.

    Attributes
    ----------
    dfa : DFA
        The DFA object.
    input_str : str
        The string to identify.
    dfa_graph : _FAGraph
        The graph of the DFA.
    input_symbols : _ManimInput
        The element in animation to show the current symbol of the `input_str`.
    """

    dfa: DFA
    input_str: str
    dfa_graph: _FAGraph
    input_symbols: _ManimInput

    def __init__(self, dfa: DFA, input_str: str, **kwargs: Any) -> None:
        """
        Parameters
        ----------
        dfa : DFA
            Stored as `self.dfa`; its graph is stored as `self.dfa_graph`.
        input_str : str
            Stored as `self.input_str`; its Manim element is stored as
            `self.input_symbols`.
        **kwargs
            The arguments for the `Scene`. Better to keep the default.
        """
        super().__init__(**kwargs)
        self.dfa = dfa
        self.input_str = input_str
        self.dfa_graph = _FAGraph(dfa)
        self.input_symbols = _ManimInput(input_str)

    def setup(self) -> None:
        """Put the diagram and the input string on the screen."""
        graph = self.dfa_graph.graph
        self.add(graph)
        self.add(self.input_symbols)
        # The element keyed by None (the null node) starts out highlighted.
        graph[None].set_color(Animate.HIGHLIGHT_COLOR)

    def construct(self) -> None:
        """Construct the animation of `self.dfa` identifying `self.input_str`."""
        # Sliding window over the last three visited states, seeded with the
        # None key of the null node so the first transition is (None, state).
        recent: deque[Optional[DFAStateT]] = deque([None], maxlen=3)
        try:
            # Enumeration starts at -1, so the first yielded state is paired
            # with symbol index -1.
            for symbol_index, state in enumerate(
                self.dfa.read_input_stepwise(self.input_str), start=-1
            ):
                recent.append(state)
                # With a full window, its two oldest states form the transition
                # highlighted in the previous step.
                if len(recent) < 3:
                    stale_transitions = ()
                else:
                    stale_transitions = ((recent[0], recent[1]),)
                fresh_transitions = ((recent[-2], recent[-1]),)
                self.play(
                    *self.dfa_graph.change_transitions(
                        stale_transitions, fresh_transitions
                    ),
                    *self.input_symbols.change_symbol(symbol_index),
                )
                self.play(*self.dfa_graph.change_states((recent[-2],), (recent[-1],)))
                self.wait()
                # A None state is treated as a rejection of the input.
                if state is None:
                    raise RejectionException
        except RejectionException:
            accepts_input = False
        else:
            accepts_input = True
        last_transition = (recent[-2], recent[-1])
        self.play(
            *self.dfa_graph.clean((last_transition,)),
            self.input_symbols.show_result(accepts_input),
        )
        self.wait()


class _NFAAnimation(manim.Scene):
    """
    Scene that animates an NFA identifying an input string.

    To generate the animation, use `_NFAAnimation.render`, which will call the
    `setup` method and the `construct` method.

    Attributes
    ----------
    nfa : NFA
        The NFA object.
    input_str : str
        The string to identify.
    nfa_graph : _FAGraph
        The graph of the NFA.
    input_symbols : _ManimInput
        The element in animation to show the current symbol of the `input_str`.
    """

    nfa: NFA
    input_str: str
    nfa_graph: _FAGraph
    input_symbols: _ManimInput

    def __init__(self, nfa: NFA, input_str: str, **kwargs: Any) -> None:
        """
        Parameters
        ----------
        nfa : NFA
            `self.nfa`
        input_str : str
            `self.input_str`
        **kwargs
            The arguments for the `Scene`. Better to keep the default.
        """
        super().__init__(**kwargs)
        self.nfa = nfa
        self.nfa_graph = _FAGraph(nfa)
        self.input_str = input_str
        self.input_symbols = _ManimInput(self.input_str)

    def setup(self) -> None:
        """Put the diagram and the input string on the screen."""
        self.add(self.nfa_graph.graph)
        self.add(self.input_symbols)
        # The element keyed by None (the null node) starts out highlighted.
        self.nfa_graph.graph[None].set_color(Animate.HIGHLIGHT_COLOR)

    def construct(self) -> None:
        """
        Construct the animation of `self.nfa` identifying `self.input_str`.

        The scene first highlights the arrow into the initial state and the
        initial state itself, then alternates, for every input symbol, between
        following the transitions labelled with that symbol and following the
        lambda (empty string) transitions. It stops early when no state is left
        and finally shows whether any current state is a final state.
        """
        self.play(Animate.highlight(self.nfa_graph.graph[None, self.nfa.initial_state]))
        self.play(
            Animate.to_default_color(self.nfa_graph.graph[None]),
            Animate.highlight(self.nfa_graph.graph[self.nfa.initial_state]),
        )
        current_transitions: tuple[tuple[Optional[NFAStateT], NFAStateT], ...] = (
            (None, self.nfa.initial_state),
        )
        current_states = {self.nfa.initial_state}

        def next_states_of(state: "NFAStateT", symbol: str) -> Iterable["NFAStateT"]:
            """Return the states reached from `state` on `symbol` (maybe none)."""
            # A state may have no entry in `transitions` at all, so avoid
            # indexing the mapping directly, which would raise `KeyError`.
            return self.nfa.transitions.get(state, {}).get(symbol) or ()

        def add_lambda_transitions() -> Iterable[tuple["NFAStateT", "NFAStateT"]]:
            """
            Add states reachable with lambda transitions into `current_states` and
            yield the lambda transitions that were followed to reach them.

            Returns
            -------
            Iterable[tuple[NFAStateT, NFAStateT]]
                The lambda transitions that can reach new states from one of
                `current_states`.
            """
            states_queue = deque(current_states)
            while states_queue:
                state = states_queue.popleft()
                for next_state in next_states_of(state, ""):
                    if next_state not in current_states:
                        current_states.add(next_state)
                        states_queue.append(next_state)
                        yield (state, next_state)

        def play_lambda_step(
            previous_transitions: Iterable[tuple[Optional["NFAStateT"], "NFAStateT"]],
        ) -> tuple[tuple["NFAStateT", "NFAStateT"], ...]:
            """Play the lambda step, if any, and return the transitions followed."""
            lambda_transitions = tuple(add_lambda_transitions())
            if lambda_transitions:
                self.play(
                    *self.nfa_graph.change_transitions(
                        previous_transitions, lambda_transitions
                    ),
                    *self.nfa_graph.highlight_states(
                        map(itemgetter(1), lambda_transitions)
                    ),
                )
            return lambda_transitions

        if lambda_transitions := play_lambda_step(current_transitions):
            current_transitions = lambda_transitions
        self.wait()
        for input_index, input_symbol in enumerate(self.input_str):
            new_transitions = tuple(
                (current_state, next_state)
                for current_state in current_states
                for next_state in next_states_of(current_state, input_symbol)
            )
            new_states = set(map(itemgetter(1), new_transitions))
            self.play(
                *self.nfa_graph.change_transitions(
                    current_transitions, new_transitions
                ),
                *self.input_symbols.change_symbol(input_index),
            )
            current_transitions = new_transitions
            self.play(*self.nfa_graph.change_states(current_states, new_states))
            self.wait()
            current_states = new_states
            if not current_states:
                # No state left: the rest of the input cannot change the result.
                break
            if lambda_transitions := play_lambda_step(current_transitions):
                self.wait()
                current_transitions = lambda_transitions
        self.play(
            *self.nfa_graph.clean(current_transitions),
            # The input is accepted when at least one current state is final.
            self.input_symbols.show_result(
                not current_states.isdisjoint(self.nfa.final_states)
            ),
        )
        self.wait()
