import unittest

from lark import Tree, Token

from mlirmut.match_template import TreeView
from mlirmut.template_fuzzer import exact_tree_match, anchor_pattern


def _leaf(data, first, second):
    """Return a ``Tree`` named *data* holding a TOKEN1/TOKEN2 token pair."""
    return Tree(data, [Token("TOKEN1", first), Token("TOKEN2", second)])


# Toy fixture tree: the root "level1" has two subtrees ("level2a", "level2b"),
# each with four leaf-level nodes ("level3a".."level3d") that carry a
# TOKEN1/TOKEN2 token pair.
TOY_TREE = Tree(
    "level1",
    [
        Tree(
            "level2a",
            [
                _leaf("level3a", "abc", "bcd"),
                _leaf("level3b", "123", "456"),
                _leaf("level3c", "abc123", "bcd456"),
                _leaf("level3d", "qwer", "tyui"),
            ],
        ),
        Tree(
            "level2b",
            [
                _leaf("level3a", "asdf", "ghjk"),
                _leaf("level3b", "qwer", "jk;l"),
                _leaf("level3c", "poiu", "zxcv"),
                _leaf("level3d", "fghj", "bnmk"),
            ],
        ),
    ],
)

# Pattern to search for: a "level2a" node that keeps only the level3b and
# level3d children of TOY_TREE's first subtree (same token values, original
# order), i.e. a non-contiguous subset of that subtree's children.
TOY_PATTERN = Tree(
    "level2a",
    [
        _leaf("level3b", "123", "456"),
        _leaf("level3d", "qwer", "tyui"),
    ],
)


class TestTreeMatching(unittest.TestCase):
    """Check that ``TOY_PATTERN`` is located inside ``TOY_TREE``."""

    def test_exact_tree_match(self):
        """Expect an exact match of the pattern on the ``level2a`` subtree."""
        level2a_subtree = TOY_TREE.children[0]
        match = exact_tree_match(level2a_subtree, TOY_PATTERN)
        self.assertIsNotNone(match)
        # The matched view, materialized back into a Tree, must equal the pattern.
        self.assertEqual(match.to_tree(), TOY_PATTERN)

    def test_anchor_pattern(self):
        """Expect the pattern to be anchored at child path ``[0]`` of the root."""
        result = anchor_pattern(TOY_TREE, TOY_PATTERN)
        self.assertIsNotNone(result)
        path, view = result
        self.assertEqual(path, [0])
        self.assertEqual(view.to_tree(), TOY_PATTERN)
    
if __name__ == "__main__":
    # Load and run the TestCase classes defined in this module.
    unittest.main(module="__main__")