# pylint: disable=missing-docstring, dangerous-default-value
import argparse
import ast
from pathlib import Path
from typing import List, Sequence

# TIR kernels that main() always reports (under the "attn_kernels" section),
# in addition to the kernels discovered from the requested entry functions.
# main() looks each of them up by name in the parsed file.
ATTN_KERNELS = [
    "batch_decode_paged_kv", "batch_prefill_paged_kv",
    "batch_prefill_ragged_kv", "fused_rope",
]


def is_call_tir(node: ast.Call):
    return ast.unparse(node.func) == "R.call_tir"


def is_call_dps_packed(node: ast.Call):
    return ast.unparse(node.func) == "R.call_dps_packed"


def get_called_functions(node: ast.FunctionDef):
    """Collect the kernels invoked anywhere inside ``node``.

    Every node under ``node`` is visited (``ast.walk``, so nested calls are
    included). The function records:
      * the attribute name ``f`` of each ``R.call_tir(cls.f, ...)`` call, and
      * the constant ``"name"`` of each ``R.call_dps_packed("name", ...)`` call.

    Args:
        node: Parsed function definition to scan.

    Returns:
        ``(tir_funcs, dps_packed_funcs)``: two lists of names. Each list is
        de-duplicated and ordered by first appearance in ``ast.walk`` order.

    Raises:
        ValueError: If an ``R.call_tir`` call's first positional argument is not
            of the form ``cls.<name>``, or an ``R.call_dps_packed`` call's first
            positional argument is not a constant.
    """
    # Dicts keep insertion order and make de-duplication O(1) per name.
    tir_funcs = {}
    dps_packed_funcs = {}
    for child in ast.walk(node):
        if not isinstance(child, ast.Call):
            continue
        callee = child.args[0] if child.args else None
        if is_call_tir(child):
            # Explicit checks instead of `assert`: they survive `python -O`, and a
            # non-Name receiver (e.g. `a.b.f`) no longer crashes on `.id`.
            if not (
                isinstance(callee, ast.Attribute)
                and isinstance(callee.value, ast.Name)
                and callee.value.id == "cls"
            ):
                raise ValueError(
                    f"expected R.call_tir(cls.<name>, ...), got: {ast.unparse(child)}"
                )
            tir_funcs[callee.attr] = None
        elif is_call_dps_packed(child):
            if not isinstance(callee, ast.Constant):
                raise ValueError(
                    f"expected R.call_dps_packed(<constant>, ...), got: {ast.unparse(child)}"
                )
            dps_packed_funcs[callee.value] = None
    return list(tir_funcs), list(dps_packed_funcs)


def get_function_params(node: ast.FunctionDef):
    """Return the parameter annotations of ``node`` as source strings.

    A parameter annotated ``T.handle`` is rendered as the ``T.Buffer(...)`` it
    is bound to by a ``T.match_buffer(<param>, ...)`` call in the body. The
    rendering reuses every argument of that call after the handle, both
    positional and keyword. A handle with no such call is reported as plain
    ``T.handle``. All other annotations are unparsed as-is.

    Args:
        node: Parsed function definition; only ``node.args.args`` (positional
            parameters) are reported.

    Returns:
        One string per positional parameter, in declaration order.

    Raises:
        ValueError: If a positional parameter has no annotation.
    """
    match_buffers = {}
    for stmt in node.body:
        for child in ast.walk(stmt):
            if not (
                isinstance(child, ast.Call)
                and ast.unparse(child.func) == "T.match_buffer"
            ):
                continue
            if not child.args or not isinstance(child.args[0], ast.Name):
                # Only a bare name can refer to a parameter. Other sources, e.g. a
                # region such as `T.match_buffer(A[0:4], ...)` nested inside a
                # block, are skipped instead of crashing on `.id`.
                continue
            buffer_args = [ast.unparse(arg) for arg in child.args[1:]]
            # Keyword arguments (e.g. `dtype=...`) belong to the buffer declaration too.
            buffer_args.extend(ast.unparse(kw) for kw in child.keywords)
            match_buffers[child.args[0].id] = f"T.Buffer({', '.join(buffer_args)})"

    params = []
    for param in node.args.args:
        if param.annotation is None:
            raise ValueError(
                f"parameter {param.arg!r} of {node.name!r} has no annotation"
            )
        annotation = ast.unparse(param.annotation)
        if annotation == "T.handle":
            # Fall back to the raw handle when no match_buffer binds it.
            annotation = match_buffers.get(param.arg, annotation)
        params.append(annotation)
    return params


def main(
    file_path: Path,
    names: Sequence[str] = ("embed", "prefill", "decode"),
):
    """Print the TIR kernels used by selected functions, with their parameters.

    The script at ``file_path`` is parsed, not executed. For each name in
    ``names``, the function definition of that name is scanned for
    ``R.call_tir`` calls. Each distinct TIR callee is printed, followed by one
    line per parameter annotation. Handle parameters are shown as the
    ``T.Buffer(...)`` built from their ``T.match_buffer`` call. The kernels in
    ``ATTN_KERNELS`` are always reported in an extra ``attn_kernels`` section.

    Args:
        file_path: Script file to analyze; read as UTF-8.
        names: Names of the functions whose TIR calls are reported. A name with
            no matching definition produces an empty section.
    """
    tree = ast.parse(Path(file_path).read_text(encoding="utf-8"))

    # Index every function definition by name in a single walk; on duplicate
    # names the last one in walk order wins.
    definitions = {
        node.name: node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
    }

    functions = {name: [] for name in names}
    functions["attn_kernels"] = ATTN_KERNELS
    for name in names:
        definition = definitions.get(name)
        if definition is not None:
            functions[name], _ = get_called_functions(definition)

    # Render every reported kernel's parameters before printing anything.
    reported = {kernel for kernels in functions.values() for kernel in kernels}
    kernel_params = {
        kernel: get_function_params(definitions[kernel])
        for kernel in reported
        if kernel in definitions
    }

    for section, kernels in functions.items():
        print(f"============={section}=============")
        for kernel in kernels:
            print(kernel)
            if kernel in kernel_params:
                print(*kernel_params[kernel], sep="\n")
            else:
                # e.g. an ATTN_KERNELS entry this file does not define; the
                # original code raised KeyError here.
                print("<definition not found>")
            print("--------------------------------")


if __name__ == "__main__":
    # Command-line entry point: one positional argument, the script to analyze.
    cli = argparse.ArgumentParser()
    cli.add_argument("path", type=str)
    cli_args = cli.parse_args()
    main(Path(cli_args.path))
