import unittest

from lark import Tree, Token

from mlirmut.match_template import TreeMask
from mlirmut.utils import get_target_node
import mlirmut.scripts.reset_masks as reset_masks

def _leaf(rule, first, second):
    """Return a ``rule`` node whose children are a TOKEN1/TOKEN2 token pair."""
    return Tree(rule, [Token("TOKEN1", first), Token("TOKEN2", second)])


# Small fixed parse tree used as test input:
#   level0 -> level1 -> {level2a, level2b}, each holding four leaf nodes
#   (level3a..level3d) that carry exactly two tokens (TOKEN1, TOKEN2).
# NOTE(review): the tests address nodes with index paths such as [0, 0, 1];
# presumably each entry is a child index from the root — confirm against
# mlirmut.utils.get_target_node.
TOY_TREE = Tree(
    "level0",
    [
        Tree(
            "level1",
            [
                Tree(
                    "level2a",
                    [
                        _leaf("level3a", "abc", "bcd"),
                        _leaf("level3b", "123", "456"),
                        _leaf("level3c", "abc123", "bcd456"),
                        _leaf("level3d", "qwer", "tyui"),
                    ],
                ),
                Tree(
                    "level2b",
                    [
                        _leaf("level3a", "asdf", "ghjk"),
                        _leaf("level3b", "qwer", "jk;l"),
                        _leaf("level3c", "poiu", "zxcv"),
                        _leaf("level3d", "fghj", "bnmk"),
                    ],
                ),
            ],
        ),
    ],
)

class TestResetMask(unittest.TestCase):
    """Tests for ``reset_masks.recursively_generalize`` applied to ``TOY_TREE``."""

    def setUp(self):
        # Rebuild the mask before every test so visibility changes made by one
        # test cannot influence another.
        self.mask = TreeMask.init_mask(TOY_TREE)

    def test_reset_mask(self):
        """Generalizing toward [0, 0, 1] leaves only it and [0, 0] visible.

        Every other Tree node of ``TOY_TREE`` (root, level1, the level2b
        subtree and the remaining level2a leaves) must be hidden.
        """
        target = [0, 0, 1]
        reset_masks.recursively_generalize(self.mask, [], target)

        # The target leaf and its direct parent are expected to stay visible.
        visible_paths = (
            target,
            [0, 0],
        )
        # All remaining Tree nodes are expected to be hidden, including the
        # root ([]) and the intermediate level1 node ([0]).
        hidden_paths = (
            [],
            [0],
            [0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [0, 1, 2],
            [0, 1, 3],
            [0, 0, 0],
            [0, 0, 2],
            [0, 0, 3],
        )

        # subTest reports every mismatching path instead of stopping at the
        # first failed assertion.
        for path in visible_paths:
            with self.subTest(path=path, expected="visible"):
                self.assertTrue(get_target_node(self.mask, path).visible)
        for path in hidden_paths:
            with self.subTest(path=path, expected="hidden"):
                self.assertFalse(get_target_node(self.mask, path).visible)
 
if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this module when it is
    # executed directly as a script.
    unittest.main(module="__main__")