from typing import Self, cast, TypeAlias
from dataclasses import dataclass

from lark import ParseTree, Tree
from lark import Token


# Location of a node inside a tree: the child index chosen at each level while
# descending from a root node. An empty list denotes the root itself.
TreePath: TypeAlias = list[int]


@dataclass(slots=True)
class TreeView:
    """A lightweight, possibly pruned, mirror of a lark tree node or token.

    A view keeps a reference to the node it was created from and carries its
    own list of child views, so the set of children can differ from the
    children of the underlying node.
    """

    # The lark node (tree or token) this view was created for.
    original: ParseTree | Token
    # The tree's ``data`` label; ``None`` when the view wraps a token.
    data: str | None
    # Views of the retained children; ``None`` when the view wraps a token.
    children: list[Self] | None

    @staticmethod
    def from_tree(node: ParseTree | Token):
        """Recursively build a view that mirrors ``node`` and all of its descendants.

        Tokens become leaf views with neither ``data`` nor ``children``.
        """
        if not isinstance(node, Token):
            label = node.data
            child_views = [TreeView.from_tree(child) for child in node.children]
            return TreeView(node, label, child_views)  # type: ignore
        return TreeView(node, None, None)

    def to_tree(self):
        """Convert the view back into lark objects.

        A token view yields its original token unchanged. Any other view
        yields a newly constructed tree labelled with ``self.data`` whose
        children are the converted child views.
        """
        if not isinstance(self.original, Token):
            label = self.data
            rebuilt_children = [view.to_tree() for view in self.children]
            return ParseTree(label, rebuilt_children)
        return self.original


@dataclass(slots=True)
class Anchoring:
    """Result of successfully anchoring a pattern at some node of a tree."""

    # View of the nodes matched by the pattern, rooted at the anchor node.
    tree_view: TreeView
    # Child indices leading to the anchor node, appended to whatever starting
    # path the search was given (empty when the search root itself matched and
    # no starting path was supplied).
    path: TreePath


def exact_tree_match(tree: ParseTree | Token, pattern: ParseTree | Token):
    """Check whether ``pattern`` is embedded in ``tree``, rooted at ``tree`` itself.

    Matching is inclusive of extra nodes between the nodes defined by the
    pattern: the children of a pattern node must match, in order, a
    subsequence of the children of the corresponding tree node, and tree
    children that match nothing are skipped.

    Matching rules, in order of precedence:

    * A ``HOLE`` token in the pattern matches any tree node whose ``data``
      equals the hole's value. It never matches a token.
    * Any other pattern token matches a tree token with the same value.
    * A ``__PARAMETER__`` pattern node matches any tree node whose ``data``
      equals the value of the parameter's first child.
    * Any other pattern node matches a tree node with the same ``data`` whose
      children contain the pattern's children as an ordered subsequence.

    Returns:
        A ``TreeView`` of the matched nodes, or ``None`` if there is no match.
        For hole, token and parameter matches the view covers the complete
        matched subtree; for structural matches it holds only the children
        that were matched by the pattern.
    """
    if isinstance(pattern, Token):
        if pattern.type == "HOLE":
            # A hole is matched by node type only. It must not fall through to
            # the literal comparison below, otherwise a hole named after a
            # rule would also match a source token that merely has that text.
            if isinstance(tree, Tree) and tree.data == str(pattern.value):
                return TreeView.from_tree(tree)
            return None
        # A regular lexical token is matched by value.
        if isinstance(tree, Token) and str(tree.value) == str(pattern.value):
            return TreeView.from_tree(tree)
        # A token pattern has no children to check, so this is final.
        return None

    # The pattern is a tree from here on, so the candidate must be one too.
    if isinstance(tree, Token):
        return None

    # A parameter only constrains the node type of the candidate.
    if pattern.data == "__PARAMETER__" and tree.data == str(
        cast(Token, pattern.children[0]).value
    ):
        return TreeView.from_tree(tree)

    if tree.data != pattern.data or len(pattern.children) > len(tree.children):
        return None

    # Greedy in-order matching: every pattern child consumes tree children
    # from a shared iterator until one matches, so a tree child is never
    # reused and relative order is preserved.
    matched_children = []
    remaining_tree_children = iter(tree.children)
    for pattern_child in pattern.children:
        for tree_child in remaining_tree_children:
            matched_subtree = exact_tree_match(tree_child, pattern_child)
            if matched_subtree is not None:
                matched_children.append(matched_subtree)
                break
        else:
            # The tree children ran out before this pattern child was matched.
            return None
    return TreeView(tree, tree.data, matched_children)


def anchor_pattern(tree: ParseTree | Token, pattern: ParseTree | Token, anchor_path: TreePath | None=None) -> Anchoring | None:
    """Find the first node, in depth-first pre-order, at which ``pattern`` matches.

    ``tree`` itself is tried first; if it does not match, each child is
    searched recursively from left to right.

    Args:
        tree: Node to start searching from.
        pattern: Pattern to anchor.
        anchor_path: Path already travelled to reach ``tree``; defaults to an
            empty path.

    Returns:
        An ``Anchoring`` holding the matched view and the path to the anchor
        node, or ``None`` when the pattern matches nowhere under ``tree``.
    """
    path_so_far = anchor_path if anchor_path is not None else []

    # Attempt a match rooted directly at this node.
    view = exact_tree_match(tree, pattern)
    if view is not None:
        return Anchoring(tree_view=view, path=path_so_far)

    # Tokens have no children to descend into.
    if isinstance(tree, Token):
        return None

    # Otherwise retry with each child as the candidate root, extending the path.
    for index, subtree in enumerate(tree.children):
        if (found := anchor_pattern(subtree, pattern, path_so_far + [index])) is not None:
            return found

    return None
