import copy
import unittest
import unittest.mock as mock

from lark import Tree, Token

from mlirmut.match_template import TreeView, MatchTemplate
from mlirmut.utils import get_target_node
from mlirmut.template_fuzzer import exact_tree_match, anchor_pattern


# Fixture tree shared by the tests below; setUp and prune_tree deep-copy it
# before mutating, so this module-level object itself is never modified.
# Root "level0" has two children:
#   * "level1", holding two subtrees "level2a" and "level2b", each with four
#     leaf rules "level3a".."level3d" that carry two tokens (TOKEN1, TOKEN2);
#   * "level_alt1", holding a single subtree "level_alt2a" whose four leaves
#     "level_alt3a".."level_alt3d" carry the same token values as the leaves
#     of "level2a", but under different rule names.
# NOTE(review): the duplicated token values under different rule names look
# intended to distinguish structural/rule matches from value matches — confirm.
TOY_TREE = Tree(
    "level0",
    [
        Tree(
            "level1",
            [
                Tree(
                    "level2a",
                    [
                        Tree(
                            "level3a", [Token("TOKEN1", "abc"), Token("TOKEN2", "bcd")]
                        ),
                        Tree(
                            "level3b", [Token("TOKEN1", "123"), Token("TOKEN2", "456")]
                        ),
                        Tree(
                            "level3c",
                            [Token("TOKEN1", "abc123"), Token("TOKEN2", "bcd456")],
                        ),
                        Tree(
                            "level3d",
                            [Token("TOKEN1", "qwer"), Token("TOKEN2", "tyui")],
                        ),
                    ],
                ),
                Tree(
                    "level2b",
                    [
                        Tree(
                            "level3a",
                            [Token("TOKEN1", "asdf"), Token("TOKEN2", "ghjk")],
                        ),
                        Tree(
                            "level3b",
                            [Token("TOKEN1", "qwer"), Token("TOKEN2", "jk;l")],
                        ),
                        Tree(
                            "level3c",
                            [Token("TOKEN1", "poiu"), Token("TOKEN2", "zxcv")],
                        ),
                        Tree(
                            "level3d",
                            [Token("TOKEN1", "fghj"), Token("TOKEN2", "bnmk")],
                        ),
                    ],
                ),
            ],
        ),
        Tree(
            "level_alt1",
            [
                Tree(
                    "level_alt2a",
                    [
                        Tree(
                            "level_alt3a",
                            [Token("TOKEN1", "abc"), Token("TOKEN2", "bcd")],
                        ),
                        Tree(
                            "level_alt3b",
                            [Token("TOKEN1", "123"), Token("TOKEN2", "456")],
                        ),
                        Tree(
                            "level_alt3c",
                            [Token("TOKEN1", "abc123"), Token("TOKEN2", "bcd456")],
                        ),
                        Tree(
                            "level_alt3d",
                            [Token("TOKEN1", "qwer"), Token("TOKEN2", "tyui")],
                        ),
                    ],
                ),
            ],
        ),
    ],
)


def prune_tree(tree, path_to_prune):
    """Return a deep copy of ``tree`` with the node at ``path_to_prune`` removed.

    The input ``tree`` is not modified. Every element of ``path_to_prune``
    except the last is passed to ``get_target_node`` to find the parent inside
    the copy. The last element is the index of the child to drop from that
    parent.
    """
    result = copy.deepcopy(tree)
    owner = get_target_node(result, path_to_prune[:-1])
    drop_at = path_to_prune[-1]
    # Rebuild the child list without the dropped entry; slicing yields a fresh
    # list, so extending it in place does not touch owner.children.
    survivors = owner.children[:drop_at]
    survivors += owner.children[drop_at + 1 :]
    owner.set(owner.data, survivors)
    return result


class TestGeneralization(unittest.TestCase):
    """Tests MatchTemplate's masking and generalization operations.

    Every test starts from a deep copy of TOY_TREE. In that copy, the fragment
    at path ``[0, 0, 1]`` has been replaced by a ``HOLE`` token, and the
    replaced subtree is passed to MatchTemplate as the fragment.
    """

    def setUp(self):
        full_context = copy.deepcopy(TOY_TREE)
        fragment_path = [0, 0, 1]
        parent = get_target_node(full_context, fragment_path[:-1])
        hole_index = fragment_path[-1]

        # Keep the original subtree as the fragment before it gets replaced.
        fragment = copy.deepcopy(parent.children[hole_index])

        # Substitute the fragment node with a hole *in its parent*, so the
        # HOLE token sits exactly at full_fragment_path in full_context.
        parent.set(
            parent.data,
            parent.children[:hole_index]
            + [Token("HOLE", "HOLE")]
            + parent.children[hole_index + 1 :],
        )

        self.template = MatchTemplate(
            full_context=full_context,
            full_fragment_path=list(fragment_path),
            fragment=fragment,
        )

    def _assert_masked(self, expected_context, expected_path):
        """Assert the template's current masked context and fragment path."""
        self.assertEqual(self.template.masked_context, expected_context)
        self.assertEqual(self.template.masked_fragment_path, expected_path)

    def test_generalize_specialize_ancestors(self):
        """Ancestors are dropped oldest-first and restored in reverse order."""
        full_context = self.template.full_context
        self._assert_masked(full_context, [0, 0, 1])

        # generalizes out the oldest ancestor
        self.assertTrue(self.template.generalize_ancestors())
        self._assert_masked(get_target_node(full_context, [0]), [0, 1])

        # generalizes out the next oldest ancestor
        self.assertTrue(self.template.generalize_ancestors())
        self._assert_masked(get_target_node(full_context, [0, 0]), [1])

        # refuse to generalize out the direct parent
        self.assertFalse(self.template.generalize_ancestors())
        self._assert_masked(get_target_node(full_context, [0, 0]), [1])

        # specializes in the grandparent
        self.assertTrue(self.template.specialize_ancestors())
        self._assert_masked(get_target_node(full_context, [0]), [0, 1])

        # specializes in the great grandparent
        self.assertTrue(self.template.specialize_ancestors())
        self._assert_masked(full_context, [0, 0, 1])

        # refuse to specialize further
        self.assertFalse(self.template.specialize_ancestors())
        self._assert_masked(full_context, [0, 0, 1])

    def test_update_mask(self):
        """Masking prunes subtrees from masked_context; unmasking restores them."""
        self.assertEqual(self.template.masked_context, self.template.full_context)

        self.template.update_mask_and_path([0, 1], False)
        pruned_context = prune_tree(self.template.full_context, [0, 1])
        self.assertEqual(self.template.masked_context, pruned_context)

        self.template.update_mask_and_path([0, 0, 0], False)
        pruned_context = prune_tree(pruned_context, [0, 0, 0])
        self.assertEqual(self.template.masked_context, pruned_context)

        # Re-enabling [0, 1] leaves only the [0, 0, 0] pruning in effect.
        self.template.update_mask_and_path([0, 1], True)
        pruned_context = prune_tree(self.template.full_context, [0, 0, 0])
        self.assertEqual(self.template.masked_context, pruned_context)

        self.template.update_mask_and_path([0, 0, 0], True)
        self.assertEqual(self.template.masked_context, self.template.full_context)

    def test_update_path(self):
        """Masking an earlier sibling shifts the masked fragment index left."""
        self.assertEqual(self.template.masked_context, self.template.full_context)
        self.assertEqual(
            self.template.masked_fragment_path, self.template.full_fragment_path
        )

        self.template.update_mask_and_path([0, 0, 0], False)
        self.assertEqual(self.template.masked_fragment_path, [0, 0, 0])

    def test_generalize_specialize_nonancestors_exact(self):
        """A chosen non-ancestor branch is pruned, then specialized back in."""
        self.assertEqual(self.template.masked_context, self.template.full_context)

        target_branch = get_target_node(self.template.full_mask, [0, 1])

        def choose_target(branches):
            # Deterministically pick the (path, data) candidate matching the
            # mask of the branch at [0, 1]; fail loudly if it is not offered.
            matches = [
                (branch_path, branch_data)
                for branch_path, branch_data in branches
                if branch_data == target_branch
            ]
            self.assertTrue(matches, "branch [0, 1] was not offered as a candidate")
            return matches[0]

        mock_random = mock.Mock()
        mock_random.choice = choose_target
        mock_random.random.return_value = 0.9
        self.assertTrue(self.template.generalize_nonancestors(rand=mock_random))
        pruned_context = prune_tree(self.template.full_context, [0, 1])
        self.assertEqual(self.template.masked_context, pruned_context)

        mock_random = mock.Mock()
        mock_random.choice = lambda branches: branches[0]
        # Identity "shuffle" keeps candidate order deterministic.
        mock_random.shuffle = lambda branches: branches
        self.assertTrue(self.template.specialize_nonancestors(rand=mock_random))
        self.assertEqual(self.template.masked_context, self.template.full_context)


if __name__ == "__main__":
    unittest.main()
