import os
import pickle
from argparse import ArgumentParser

from lark import UnexpectedCharacters, Token
from tqdm import tqdm

from mlirmut.match_template import MatchTemplate, TreeMask

def recursively_generalize(
    mask: TreeMask | Token, current_path: list[int], fragment_path: list[int]
):
    """Hide every mask node that is not the fragment or the fragment's parent.

    Walks the mask tree rooted at ``mask`` in pre-order. Nodes are addressed by
    a list of child indices from the root. A ``TreeMask`` node keeps its
    current ``visible`` value only if its path equals ``fragment_path`` or
    ``fragment_path[:-1]``. Every other ``TreeMask`` node gets
    ``visible = False``. ``Token`` leaves are skipped entirely.

    NOTE(review): ancestors above the fragment's parent, including the root
    when the fragment is two or more levels deep, are hidden as well. Confirm
    this is intended.

    Args:
        mask: Root of the (sub)tree to process.
        current_path: Child-index path of ``mask`` from the tree root.
        fragment_path: Child-index path of the fragment to keep visible.
    """
    parent_path = fragment_path[:-1]
    # Explicit stack instead of recursion. Children are pushed in reverse so
    # they are popped in their natural left-to-right (pre-order) sequence.
    pending = [(mask, current_path)]
    while pending:
        node, path = pending.pop()
        if isinstance(node, Token):
            continue
        if path not in (parent_path, fragment_path):
            node.visible = False
        expanded = [(child, path + [idx]) for idx, child in enumerate(node.children)]
        pending.extend(reversed(expanded))


def generalize_all(template: MatchTemplate):
    """Generalize ``template`` in place, keeping only the fragment and its parent.

    First hides, in ``template.full_mask``, every node that is neither the
    fragment nor its immediate parent (see ``recursively_generalize``). Then
    recomputes the two masked views stored on the template from the updated
    mask:

    - ``masked_context``: ``full_mask.mask_tree`` applied to ``full_context``.
    - ``masked_fragment_path``: ``full_mask.mask_path`` applied to
      ``full_fragment_path``.
    """
    mask = template.full_mask
    fragment_path = template.full_fragment_path

    recursively_generalize(mask=mask, current_path=[], fragment_path=fragment_path)

    template.masked_context = mask.mask_tree(
        root=template.full_context, root_path=fragment_path
    )
    template.masked_fragment_path = mask.mask_path(fragment_path)

if __name__ == "__main__":
    # Usage: <script> INPUT_DIR OUTPUT_DIR
    # INPUT_DIR contains one sub-directory per group, each holding pickled
    # templates (expected to be MatchTemplate instances). The same layout is
    # written under OUTPUT_DIR, with every template passed through
    # generalize_all first.
    parser = ArgumentParser()
    parser.add_argument("input_dir")
    parser.add_argument("output_dir")
    args = parser.parse_args()

    # scandir iterators are used as context managers so their OS handles are
    # released promptly.
    with os.scandir(args.input_dir) as group_entries:
        for dirent in group_entries:
            # Only sub-directories hold templates. Scanning a plain file would
            # raise NotADirectoryError.
            if not dirent.is_dir():
                continue
            out_subdir = os.path.join(args.output_dir, dirent.name)
            # makedirs (unlike mkdir) also creates output_dir itself when it
            # does not exist yet. exist_ok keeps reruns idempotent.
            os.makedirs(out_subdir, exist_ok=True)
            with os.scandir(dirent.path) as file_entries:
                for fileent in file_entries:
                    if not fileent.is_file():
                        continue
                    # NOTE(security): pickle.load can execute arbitrary code
                    # embedded in the file. Only run this on trusted dumps.
                    with open(fileent.path, "rb") as fin:
                        template = pickle.load(fin)
                    generalize_all(template)
                    with open(os.path.join(out_subdir, fileent.name), "wb") as fout:
                        pickle.dump(template, fout)