import copy
from dataclasses import dataclass, field
from typing import Callable, Self, cast
from random import Random

from lark import Token
from lark import ParseTree as Tree

from mlirmut.utils import get_target_node
from mlirmut.tree_utils import Anchoring, TreeView, anchor_pattern, exact_tree_match


@dataclass(slots=True)
class TreeMask:
    """Visibility mask mirroring the shape of a parse tree.

    Each mask node records whether the corresponding tree node is visible.
    Terminal masks stand for tokens and have ``children=None``; non-terminal
    masks hold one child mask per child of the mirrored tree node, in order.
    """

    # True when this mask node mirrors a token (leaf)
    is_terminal: bool
    # rule name of the mirrored tree node; None for terminals
    data: str | None
    # one mask per child of the mirrored tree node; None for terminals
    children: list[Self] | None
    # whether the mirrored node is kept by mask_tree / counted by mask_path
    visible: bool

    def mask_tree(self, root: Tree | Token, root_path: list[int]):
        """Return a tree built from ``root`` with every invisible node removed.

        Args:
            root: tree (or token) mirrored by this mask.
            root_path: child-index path from ``root`` towards the fragment.
                While the current node is invisible, the next node along this
                path becomes the new root.

        Returns:
            A new tree for a visible non-terminal, ``root`` itself for a
            visible terminal, and None for an invisible terminal.

        Raises:
            ValueError: if an invisible non-terminal is reached with an empty
                ``root_path``.
        """
        if self.is_terminal:
            return root if self.visible else None

        if self.visible:
            # Invisible children are dropped entirely, so root_path only
            # matters while descending through invisible nodes.
            return Tree(
                root.data,
                [
                    child_mask.mask_tree(root=child_node, root_path=root_path[1:])
                    for child_mask, child_node in zip(self.children, root.children)
                    if child_mask.visible
                ],
            )

        # if the current node is not visible, then make the next node along the
        # root path the root node
        if len(root_path) == 0:
            raise ValueError(
                "The current root is not visible, "
                "but the root path is empty so we "
                "don't know what node to use as the next root."
            )

        next_root = root.children[root_path[0]]
        next_root_mask = self.children[root_path[0]]
        next_root_path = root_path[1:]
        return next_root_mask.mask_tree(root=next_root, root_path=next_root_path)

    def mask_path(self, path: list[int]):
        """Translate a child-index path in the full tree into the masked tree.

        Invisible nodes on the path contribute no entry (mask_tree skips
        them). At a visible node the index is renumbered so that it only
        counts visible siblings.

        Raises:
            ValueError: if the path enters an invisible child of a visible
                node, continues past a terminal, or a visible node's index is
                not among its children.
        """
        # if the path is empty, then there is nothing to mask
        if len(path) == 0:
            return []

        if self.is_terminal:
            raise ValueError(f"path {path} continues past a terminal node")

        # an invisible node is not part of the masked tree: drop its index
        if not self.visible:
            return self.children[path[0]].mask_path(path[1:])

        # a visible node keeps its index, renumbered to skip invisible children
        masked_idx = 0
        for absolute_idx, child in enumerate(self.children):
            if absolute_idx == path[0]:
                # direct ancestors of the fragment must form a visible chain
                if not child.visible:
                    raise ValueError(f"child {child} is not visible")
                return [masked_idx] + child.mask_path(path[1:])
            if child.visible:
                masked_idx += 1

        raise ValueError(
            f"path index {path[0]} not found in range({len(self.children)})"
        )

    @staticmethod
    def init_mask(node: Tree | Token):
        """Build the initial mask for ``node``.

        Every node is invisible except ``HOLE`` tokens. A ``__PARAMETER__``
        node is invisible, but its children get visible terminal masks.
        """
        if isinstance(node, Token):
            # the HOLE token should be visible,
            # all other tokens are by default invisible
            return TreeMask(
                is_terminal=True, visible=node.type == "HOLE", data=None, children=None
            )

        # parameters should be invisible, but have their children visible
        if node.data == "__PARAMETER__":
            return TreeMask(
                is_terminal=False,
                visible=False,
                data=node.data,
                children=[
                    TreeMask(is_terminal=True, visible=True, data=None, children=None)
                    for _ in node.children
                ],
            )

        # all other nodes are invisible
        return TreeMask(
            is_terminal=False,
            visible=False,
            data=node.data,
            children=[TreeMask.init_mask(child) for child in node.children],
        )


@dataclass(slots=True)
class MatchTemplate:
    """A context pattern with a hole where a fragment was cut out.

    ``full_context`` is the complete tree with a ``HOLE`` token at
    ``full_fragment_path`` (the layout produced by ``tree_to_templates``), and
    ``fragment`` is the removed subtree. A :class:`TreeMask` over the full
    context selects which nodes form ``masked_context``, the pattern used for
    matching. ``generalize_*`` methods hide nodes; ``specialize_*`` methods
    reveal them.
    """

    full_context: Tree
    full_fragment_path: list[int]
    fragment: Tree
    # visibility mask mirroring full_context (built in __post_init__)
    full_mask: TreeMask = field(init=False)
    # full_context restricted to visible nodes; a bare HOLE token while no
    # ancestor of the hole is visible
    masked_context: Tree | Token = field(init=False)
    # full_fragment_path without entries for invisible ancestors, with
    # indices renumbered to skip invisible siblings
    masked_fragment_path: list[int] = field(init=False)
    # parameter id -> deep copy of the subtree the parameter replaced
    original_values: dict[int, Tree] = field(init=False)

    # snapshot taken by start_update() and restored by abort_update()
    _backup_full_mask: TreeMask | None = None
    _backup_masked_fragment_path: list[int] | None = None

    def __post_init__(self):
        # Parameterize first: it rewrites full_context and fragment, and the
        # mask must mirror the rewritten context.
        self.parameterize_template()

        self.full_mask = TreeMask.init_mask(self.full_context)
        self.masked_context = self.full_mask.mask_tree(
            root=self.full_context, root_path=self.full_fragment_path
        )
        self.masked_fragment_path = self.full_mask.mask_path(self.full_fragment_path)

    def start_update(self):
        """Snapshot the mask state so a later abort_update() can restore it."""
        self._backup_full_mask = copy.deepcopy(self.full_mask)
        self._backup_masked_fragment_path = copy.deepcopy(self.masked_fragment_path)

    def abort_update(self):
        """Restore the state saved by the last start_update().

        Raises:
            RuntimeError: if there is no snapshot, either because
                start_update() was never called or because the update was
                already committed.
        """
        if self._backup_full_mask is None or self._backup_masked_fragment_path is None:
            raise RuntimeError("abort_update() called without a pending start_update()")
        self.full_mask = self._backup_full_mask
        self.masked_fragment_path = self._backup_masked_fragment_path
        self.masked_context = self.full_mask.mask_tree(
            self.full_context, self.full_fragment_path
        )

    def commit_update(self):
        """Discard the snapshot taken by start_update()."""
        self._backup_full_mask = None
        self._backup_masked_fragment_path = None

    def masked_context_path(self):
        """Path in full_context to the root of masked_context.

        This is the prefix of full_fragment_path that masked_fragment_path
        does not cover.
        """
        return self.full_fragment_path[
            : len(self.full_fragment_path) - len(self.masked_fragment_path)
        ]

    def parameterize_template(self):
        """Replace subtrees shared by the context and the fragment with parameters.

        Walks full_context breadth-first, in the form it had before any
        replacement. Any child accepted by ``can_parameterize`` that also
        occurs (by equality) in the fragment is replaced everywhere in both
        trees by a ``__PARAMETER__(NODE_TYPE, PARAM_ID)`` node. A deep copy of
        the replaced subtree is stored in ``original_values[PARAM_ID]``.
        Replaced children are not descended into.

        Rebinds ``self.full_context`` and ``self.fragment`` to the rewritten
        trees and sets ``self.original_values``.
        """

        def make_parameter(node_type, param_id):
            # fresh node per call so the context and fragment never share one
            return Tree(
                "__PARAMETER__",
                [Token("NODE_TYPE", node_type), Token("PARAM_ID", param_id)],
            )

        original_values = dict()
        n_params = 0
        node_queue = [self.full_context]
        while len(node_queue) > 0:
            current_node = node_queue.pop(0)
            for child in current_node.children:
                if not can_parameterize(child):
                    continue
                candidate_nodes = set(self.fragment.find_data(child.data))
                child_replaced = False
                for candidate_node in candidate_nodes:
                    if candidate_node == child:
                        child_replaced = True
                        original_values[n_params] = copy.deepcopy(child)
                        self.full_context = replace_all(
                            self.full_context,
                            child,
                            make_parameter(child.data, n_params),
                        )
                        self.fragment = replace_all(
                            self.fragment,
                            child,
                            make_parameter(child.data, n_params),
                        )
                        n_params += 1
                if not child_replaced:
                    node_queue.append(child)
        self.original_values = original_values

    def update_mask_and_path(self, path: list[int], visible: bool):
        """Set the visibility of the mask node at ``path``, then refresh.

        Recomputes masked_context and masked_fragment_path.

        Raises:
            ValueError: (debug check) if, on full_fragment_path excluding the
                hole, a visible node is followed by an invisible one. The
                mask has already been modified when this is raised.
        """
        get_target_node(self.full_mask, path).visible = visible

        # TODO: for debugging only, remove later
        def validate_fragment_path_visibility():
            current_node = self.full_mask
            visibility_chain = [current_node.visible]
            seen_visible = current_node.visible
            raise_error = False
            for path_entry in self.full_fragment_path[:-1]:
                current_node = current_node.children[path_entry]
                visibility_chain.append(current_node.visible)
                if seen_visible and not current_node.visible:
                    raise_error = True
                if current_node.visible:
                    seen_visible = True
            if raise_error:
                raise ValueError(
                    f"broken visibility chain: {visibility_chain} {self.full_fragment_path}"
                )

        validate_fragment_path_visibility()
        # TODO: end debugging
        self.masked_context = self.full_mask.mask_tree(
            root=self.full_context, root_path=self.full_fragment_path
        )
        self.masked_fragment_path = self.full_mask.mask_path(self.full_fragment_path)

    def generalize_ancestors(self):
        """Hide the topmost visible ancestor of the hole.

        The hole and its parent are never hidden by this method.

        Returns:
            True if a node was hidden, False if no eligible ancestor is visible.
        """
        # ancestors from the root down to the hole's grandparent
        ancestor_masks = collect_nodes(self.full_mask, self.full_fragment_path[:-2])
        for depth, mask in enumerate(ancestor_masks):
            if mask.visible:
                self.update_mask_and_path(self.full_fragment_path[:depth], False)
                return True
        return False

    def specialize_ancestors(self):
        """Reveal the deepest hidden ancestor of the hole.

        Returns:
            True if a node was revealed, False if every ancestor is visible.
        """
        # ancestors from the root down to the hole's parent (hole excluded)
        ancestor_masks = collect_nodes(self.full_mask, self.full_fragment_path[:-1])

        # Walk upwards from the parent. Given the visibility-chain invariant,
        # the first hidden node sits directly above the visible chain.
        for depth in range(len(ancestor_masks) - 1, -1, -1):
            if not ancestor_masks[depth].visible:
                self.update_mask_and_path(self.full_fragment_path[:depth], True)
                return True
        return False

    def generalize_nonancestors(self, rand: Random, depth_weight: float = 0.5):
        """Hide a random visible node that is not an ancestor of the hole.

        Candidate branches are the visible children, off the hole's path, of
        each node on the path from the masked-context root to the hole's
        parent. One branch is chosen uniformly, and generalize_random_node
        then picks the node to hide within it.

        Returns:
            False if there are no candidate branches, otherwise True.
        """
        # start at the root node of the masked context
        context_path = self.masked_context_path()
        current_node = get_target_node(self.full_mask, context_path)

        # collect all visible branches
        branches = []
        for rel_path_idx in range(len(self.masked_fragment_path)):
            abs_path_idx = rel_path_idx + len(context_path)
            next_node_idx = self.full_fragment_path[abs_path_idx]
            for idx, child in enumerate(current_node.children):
                if idx != next_node_idx and child.visible:
                    branches.append(
                        (self.full_fragment_path[:abs_path_idx] + [idx], child)
                    )
            current_node = current_node.children[next_node_idx]

        if len(branches) == 0:
            return False

        # choose a random visible branch to generalize
        branch_path, branch = rand.choice(branches)
        return self.generalize_random_node(branch, branch_path, rand, depth_weight)

    def specialize_nonancestors(self, rand: Random):
        """Reveal a hidden node that is not an ancestor of the hole.

        Candidate branches are all children, off the hole's path, of each node
        on the path from the masked-context root to the hole's parent. The
        branches are tried in random order until specialize_random_node
        succeeds.

        Returns:
            True if a node was revealed, otherwise False.
        """
        # start at the root node of the masked context
        context_path = self.masked_context_path()
        current_node = get_target_node(self.full_mask, context_path)

        # collect all branches
        branches = []
        for rel_path_idx in range(len(self.masked_fragment_path)):
            abs_path_idx = rel_path_idx + len(context_path)
            next_node_idx = self.full_fragment_path[abs_path_idx]
            for idx, child in enumerate(current_node.children):
                if idx != next_node_idx:
                    branches.append(
                        (self.full_fragment_path[:abs_path_idx] + [idx], child)
                    )
            current_node = current_node.children[next_node_idx]

        if len(branches) == 0:
            return False

        # choose a random branch until one can be specialized
        rand.shuffle(branches)
        for branch_path, branch in branches:
            if self.specialize_random_node(branch, branch_path, rand):
                return True
        return False

    def generalize_random_node(
        self,
        start_node: TreeMask,
        target_path: list[int],
        rand: Random,
        depth_weight: float,
    ):
        """Hide a node at or below ``start_node``, descending at random.

        At each step, with probability ``depth_weight``, the walk moves to a
        uniformly chosen visible child. It stops when the coin says stop or no
        visible child exists, and the node reached is hidden. Assumes the
        start node does not include the fragment.

        Args:
            start_node: mask node to start from.
            target_path: path of ``start_node`` in full_mask (not modified).
            rand: source of randomness.
            depth_weight: probability of descending one more level.

        Returns:
            Always True.
        """
        path = list(target_path)  # never mutate the caller's list
        current_node = start_node

        # randomly choose to continue down the tree or stop
        while rand.random() < depth_weight:
            # terminal masks have children=None: treat them as childless
            visible_children = [
                (idx, child)
                for idx, child in enumerate(current_node.children or [])
                if child.visible
            ]

            # if there are no visible children, we must stop
            if len(visible_children) == 0:
                break

            # choose a random child to go down
            idx, current_node = rand.choice(visible_children)
            path.append(idx)

        # stop and set the current node to invisible
        self.update_mask_and_path(path, False)
        return True

    def specialize_random_node(
        self,
        start_node: TreeMask,
        target_path: list[int],
        rand: Random,
    ):
        """Reveal a random hidden node on the edge of the visible subtree.

        Candidates are hidden nodes under ``start_node`` whose ancestors below
        ``start_node`` are all visible. Assumes the start node does not
        include the fragment.

        Returns:
            True if a node was revealed, False if there is no candidate.
        """

        def collect_invisible_nodes(node: TreeMask, path: list[int]) -> list[list[int]]:
            # a hidden node is a frontier candidate; nothing below it is
            if not node.visible:
                return [path]
            # visible terminals (children=None) have nothing below them
            if node.is_terminal:
                return []
            found = []
            for idx, child in enumerate(node.children):
                found += collect_invisible_nodes(child, path + [idx])
            return found

        branches = collect_invisible_nodes(start_node, [])

        if len(branches) == 0:
            return False

        # randomly choose a branch to specialize
        branch_path = rand.choice(branches)
        self.update_mask_and_path(target_path + branch_path, True)
        return True

    def generalize_context(
        self,
        rand: Random,
        ancestor_weight: float = 0.5,
    ):
        """Hide one node, preferring ancestors with probability ``ancestor_weight``.

        If the preferred strategy fails, the other one is tried.

        Returns:
            True if a node was hidden, otherwise False.
        """
        if rand.random() < ancestor_weight:
            return self.generalize_ancestors() or self.generalize_nonancestors(rand)
        return self.generalize_nonancestors(rand) or self.generalize_ancestors()

    def specialize_context(
        self,
        rand: Random,
        ancestor_weight: float = 0.5,
    ):
        """Reveal one node, preferring ancestors with probability ``ancestor_weight``.

        If the preferred strategy fails, the other one is tried.

        Returns:
            True if a node was revealed, otherwise False.
        """
        # choose whether to specialize ancestors or nonancestors,
        # falling back to the alternative if it fails
        if rand.random() < ancestor_weight:
            return self.specialize_ancestors() or self.specialize_nonancestors(rand)
        return self.specialize_nonancestors(rand) or self.specialize_ancestors()

    def yield_paths_to_specialize(self):
        """Yield mask paths to try revealing, nearest the hole first.

        Protocol: start with ``next()``. After handling each yielded path,
        ``send()`` whether that path is visible now.

        The walk visits each ancestor of the hole in turn, from its parent up
        to the root. It yields the ancestor itself, then breadth-first its
        descendants outside the branch that leads to the hole. A descendant's
        children are yielded only if True was sent for it and it is not
        terminal. The upward walk stops at the first ancestor for which a
        falsy value is sent.
        """

        def yield_bfs_paths(
            current_node: TreeMask, current_path: list[int], next_child_idx: int
        ):
            # breadth-first over current_node's subtree, skipping the child at
            # next_child_idx (the branch leading to the hole)
            node_queue = []
            for child_idx, child in enumerate(current_node.children):
                if child_idx == next_child_idx:
                    continue
                child_path = current_path + [child_idx]
                path_now_visible = yield child_path
                # children of hidden or terminal nodes cannot be yielded
                if path_now_visible and not child.is_terminal:
                    node_queue.append((child, child_path))

            while len(node_queue) > 0:
                current_node, current_path = node_queue.pop(0)
                for child_idx, child in enumerate(current_node.children):
                    child_path = current_path + [child_idx]
                    path_now_visible = yield child_path
                    if path_now_visible and not child.is_terminal:
                        node_queue.append((child, child_path))

        # index into full_fragment_path: len-1 is the hole, len-2 its parent,
        # and -1 stands for the root (empty path)
        fragment_path_idx = len(self.full_fragment_path) - 2
        while fragment_path_idx >= -1:
            # +1 so the slice includes the node referenced by the index
            ancestor_path = self.full_fragment_path[: fragment_path_idx + 1]
            path_now_visible = yield ancestor_path
            if not path_now_visible:
                # A hidden ancestor's subtree cannot be explored, and revealing
                # anything above it would break the visible chain.
                break

            ancestor_node = get_target_node(root=self.full_mask, path=ancestor_path)
            # the caller reported this ancestor visible; sanity-check it
            assert ancestor_node.visible

            yield from yield_bfs_paths(
                current_node=ancestor_node,
                current_path=self.full_fragment_path[: fragment_path_idx + 1],
                next_child_idx=self.full_fragment_path[fragment_path_idx + 1],
            )
            fragment_path_idx -= 1

    def fully_specialize(self, target_tree: Tree, initial_anchoring: Anchoring):
        """Greedily reveal every node that keeps the pattern matching ``target_tree``.

        Paths come from yield_paths_to_specialize, nearest the hole first.
        Each hidden path is revealed and kept only if masked_context still
        exactly matches the subtree of ``target_tree`` at the anchor path;
        otherwise it is hidden again. The anchor moves up one level when the
        parent of the current masked-context root is revealed.

        Args:
            target_tree: tree the pattern must keep matching.
            initial_anchoring: starting anchor. Its ``path`` locates, in
                ``target_tree``, the subtree the pattern is matched against.

        Returns:
            Anchoring for the final pattern.
        """
        node_path_generator = self.yield_paths_to_specialize()
        node_path = next(node_path_generator)
        anchor_path = initial_anchoring.path
        while True:
            if get_target_node(root=self.full_mask, path=node_path).visible:
                # Already part of the pattern: report it as visible so the
                # generator explores its children and further ancestors.
                now_visible = True
            else:
                # try setting it visible
                parent_path = self.masked_context_path()[:-1]
                self.update_mask_and_path(node_path, True)
                # revealing the masked-context root's parent moves the root up,
                # so the anchor in target_tree must move up with it
                candidate_anchor = (
                    anchor_path[:-1] if node_path == parent_path else anchor_path
                )
                tree_view = exact_tree_match(
                    tree=get_target_node(root=target_tree, path=candidate_anchor),
                    pattern=self.masked_context,
                )
                now_visible = tree_view is not None
                if now_visible:
                    anchor_path = candidate_anchor
                else:
                    # revert: the revealed node breaks the match
                    self.update_mask_and_path(node_path, False)

            # get the next path
            try:
                node_path = node_path_generator.send(now_visible)
            except StopIteration:
                break

        tree_view = exact_tree_match(
            tree=get_target_node(root=target_tree, path=anchor_path),
            pattern=self.masked_context,
        )
        assert tree_view is not None

        return Anchoring(tree_view=tree_view, path=anchor_path)


def collect_nodes(tree, path):
    """Return every node visited when following ``path`` from ``tree``.

    The result starts with ``tree`` itself and has ``len(path) + 1`` entries;
    entry ``i`` is the node reached after the first ``i`` steps.
    """
    node = tree
    visited = [node]
    for child_index in path:
        node = node.children[child_index]
        visited.append(node)
    return visited


def can_parameterize(node):
    """A heuristic to determine if a node can be parameterized.

    Tokens are rejected, and so is any tree with at least one token child.
    Only trees whose children are all subtrees (or that have no children)
    qualify.
    """
    return not isinstance(node, Token) and all(
        not isinstance(child, Token) for child in node.children
    )


def replace_all(tree, old, new):
    """Return ``tree`` with every subtree equal to ``old`` replaced by ``new``.

    Matching uses ``==``, so all structurally equal subtrees are replaced, and
    the same ``new`` object is inserted at every match. Tree nodes are
    rebuilt, so the input is not modified; unmatched tokens are reused as-is.
    """
    if tree == old:
        return new
    if not isinstance(tree, Token):
        rebuilt_children = [replace_all(child, old, new) for child in tree.children]
        return Tree(tree.data, rebuilt_children)
    return tree


def instantiate_template(template: "MatchTemplate", params: dict):
    """Instantiate a template's fragment with the given parameters.

    Each ``__PARAMETER__`` node in a copy of ``template.fragment`` is rewritten
    into a node of its recorded type. Its children are taken from
    ``params[param_id]``, falling back to ``template.original_values`` when
    the id is not in ``params``.

    Args:
        template: template whose fragment is instantiated; not modified.
        params: maps parameter id -> tree whose children become the children
            of the instantiated node.

    Returns:
        A new tree that shares no nodes with the template or ``params``.

    Raises:
        KeyError: if a parameter id is in neither ``params`` nor
            ``template.original_values``.
    """
    instance = copy.deepcopy(template.fragment)
    # materialize first: the loop rewrites the very nodes it visits
    for param_node in list(instance.find_data("__PARAMETER__")):
        node_type = cast("Token", param_node.children[0]).value
        param_id = cast("Token", param_node.children[1]).value
        # Look up lazily: dict.get(key, default) would evaluate
        # original_values[param_id] even when params supplies the value.
        if param_id in params:
            param_value = params[param_id]
        else:
            param_value = template.original_values[param_id]
        # copy so the instance never aliases the template's or caller's lists
        param_node.set(node_type, copy.deepcopy(param_value.children))
    return instance


def extract_params(concrete, context):
    """Collect parameter values by walking ``concrete`` alongside ``context``.

    Wherever ``context`` has a ``__PARAMETER__`` node, the node at the same
    position in ``concrete`` is recorded (via ``to_tree()``) under that
    parameter's id. Children are paired positionally with ``zip``; tokens in
    ``context`` contribute nothing.

    Returns:
        dict mapping parameter id -> extracted tree.

    Raises:
        ValueError: if the node found at a parameter's position is not of the
            parameter's recorded node type.
    """
    params = dict()
    if isinstance(context, Token):
        return params
    if context.data == "__PARAMETER__":
        node_type = context.children[0].value
        param_id = context.children[1].value
        # tokens have no .data; report them through the same ValueError
        concrete_type = getattr(concrete, "data", None)
        if concrete_type != node_type:
            raise ValueError(
                f"parameter {param_id!r} expects a {node_type!r} node, "
                f"found {concrete_type!r}"
            )
        params[param_id] = concrete.to_tree()
        return params
    for concrete_child, context_child in zip(concrete.children, context.children):
        params |= extract_params(concrete_child, context_child)
    return params


def tree_to_templates(root, current_node=None, current_path=None):
    """Decompose ``root`` into one MatchTemplate per eligible subtree.

    Every non-token child of every node becomes a fragment, except children
    whose rule is ``optional``. The template's context is a copy of ``root``
    with that child replaced by a ``HOLE`` token carrying the child's rule
    name.

    Args:
        root: the full parse tree; it is only copied, never modified.
        current_node: node whose children become fragments; defaults to
            ``root`` (internal recursion parameter).
        current_path: child-index path from ``root`` to ``current_node``;
            defaults to ``[]`` (internal recursion parameter).

    Returns:
        list of MatchTemplate. Templates for a node's own children precede
        those built from its descendants.
    """
    templates = []
    if current_node is None:
        current_node = root
    if current_path is None:
        current_path = []

    # create fragments from each child
    for cidx, child in enumerate(current_node.children):
        # exclude tokens and optional nodes
        if isinstance(child, Token) or child.data == "optional":
            continue
        # The deep copy already duplicates the siblings, so patch the hole
        # into the copied node instead of deep-copying the siblings again.
        full_context = copy.deepcopy(root)
        get_target_node(full_context, current_path).children[cidx] = Token(
            "HOLE", str(child.data)
        )

        templates.append(
            MatchTemplate(
                full_context=full_context,
                full_fragment_path=current_path + [cidx],
                fragment=copy.deepcopy(child),
            )
        )

    # recurse into children
    for cidx, child in enumerate(current_node.children):
        if isinstance(child, Token):
            continue
        templates += tree_to_templates(
            root=root, current_node=child, current_path=current_path + [cidx]
        )

    return templates
