import tvm
import numpy as np

# Reference build: device name passed to tvm.device() and path of the
# compiled module loaded in main().
GOLDEN_DEVICE = "cuda"
GOLDEN_LIB = "cases/5/cuda.so"

# Build under test, compared against the reference one.
BUGGY_DEVICE = "vulkan"
BUGGY_LIB = "cases/5/vulkan.so"

# np.allclose tolerances for operators whose name does not contain
# "matmul", "gemm" or "gemv".
ATOL = 1e-3
RTOL = 1e-2

# Looser np.allclose tolerances for the matmul/gemm/gemv operators.
MATMUL_ATOL = 1
MATMUL_RTOL = 1e-1

# Values swept for the symbolic dimensions of the operator signatures.
SEQ_CANDIDATES = [1, 2, 4, 8, 16, 32]
BATCH_SIZE = [1, 2, 4, 8]
VOCAB_SIZES = [151936]

# Operator name -> (number of passing runs, number of failing runs).
OP_RESULTS = {}


def random_array(shape, dtype):
    """Generate random test data of the requested shape and dtype.

    Integer dtypes (signed or unsigned) are drawn uniformly from ``[0, 100)``;
    every other dtype is drawn from a standard normal distribution and then
    cast to ``dtype``.

    Args:
        shape: Tuple of dimension sizes. An empty tuple requests a scalar.
        dtype: Anything ``np.dtype`` accepts: a name such as ``"float16"``,
            a NumPy scalar type, or a ``np.dtype`` instance.

    Returns:
        A ``np.ndarray`` for a non-empty ``shape``; a plain Python scalar
        (obtained with ``.item()``) when ``shape`` is empty.
    """
    # Classify through NumPy's type hierarchy instead of matching "int" in
    # the dtype name: that also handles dtype objects and integer aliases
    # such as "i4" or "short" whose names do not contain "int".
    dtype = np.dtype(dtype)
    is_scalar = len(shape) == 0

    if np.issubdtype(dtype, np.integer):
        if is_scalar:
            return np.array(np.random.randint(0, 100), dtype=dtype).item()
        return np.random.randint(0, 100, size=shape).astype(dtype)

    if is_scalar:
        # randn(*shape) with an empty shape draws a single Python float.
        return np.array(np.random.randn(*shape), dtype=dtype).item()
    return np.random.randn(*shape).astype(dtype)


def test_operator(op, types, golden_lib, buggy_lib):
    """Run one operator on both builds with identical random inputs and compare.

    Args:
        op: Name of the function looked up in both modules.
        types: Sequence of ``(shape, dtype)`` pairs, one per positional
            argument of the operator, in call order. The last entry is the
            only argument compared after the call.
        golden_lib: Module providing the reference implementation.
        buggy_lib: Module providing the implementation under test.

    Side effects:
        Prints one pass/fail line and increments the matching counter of the
        ``(passes, fails)`` tuple stored under ``op`` in ``OP_RESULTS``.
    """
    # One set of host-side values, so both devices see exactly the same data.
    host_args = [random_array(shape, dtype) for shape, dtype in types]

    golden_dev = tvm.device(GOLDEN_DEVICE)
    buggy_dev = tvm.device(BUGGY_DEVICE)

    def _upload(device):
        # Arrays are copied to the device; scalars are passed through as-is.
        return [
            tvm.nd.array(arg, device=device) if isinstance(arg, np.ndarray) else arg
            for arg in host_args
        ]

    def _to_host(value):
        # Device arrays are copied back; scalars are returned unchanged.
        return value.numpy() if isinstance(value, tvm.nd.NDArray) else value

    golden_inputs = _upload(golden_dev)
    buggy_inputs = _upload(buggy_dev)

    golden_func = golden_lib.get_function(op, query_imports=True)
    buggy_func = buggy_lib.get_function(op, query_imports=True)

    golden_func(*golden_inputs)
    buggy_func(*buggy_inputs)

    # Compare the output in the last buffer.
    golden_np = _to_host(golden_inputs[-1])
    buggy_np = _to_host(buggy_inputs[-1])

    # Matmul-like operators get the looser tolerance pair.
    is_matmul = any(tag in op for tag in ("matmul", "gemm", "gemv"))
    atol, rtol = (MATMUL_ATOL, MATMUL_RTOL) if is_matmul else (ATOL, RTOL)

    passes, fails = OP_RESULTS.setdefault(op, (0, 0))
    if np.allclose(golden_np, buggy_np, atol=atol, rtol=rtol):
        print(f"✅ Operator {op} matches within tolerance (atol={atol}, rtol={rtol})")
        passes += 1
    else:
        # Subtract in float64: in the output's own dtype a float16 difference
        # can overflow to inf and an unsigned difference wraps around, which
        # would make the reported number meaningless.
        diff = np.max(
            np.abs(
                np.asarray(golden_np, dtype=np.float64)
                - np.asarray(buggy_np, dtype=np.float64)
            )
        )
        print(f"❌ Operator {op} mismatch! Max diff: {diff}")
        fails += 1
    OP_RESULTS[op] = (passes, fails)


def main():
    """Compare every enabled operator between the two builds and report failures.

    Loads the modules at GOLDEN_LIB and BUGGY_LIB, calls test_operator() once
    per operator and per swept dimension value, and finally prints each
    operator that failed at least one comparison.

    The comment above each call lists the buffers handed to that operator, in
    call order; the last one is the buffer test_operator() compares.
    """
    golden_lib = tvm.runtime.load_module(GOLDEN_LIB)
    buggy_lib = tvm.runtime.load_module(BUGGY_LIB)

    # --- embed ---
    # fused_dequantize_take2
    #   T.Buffer((vocab_size, 256), 'uint32')
    #   T.Buffer((vocab_size, 64), 'float16')
    #   T.Buffer((seq_len,), 'int32')
    #   T.Buffer((seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        for vocab_size in VOCAB_SIZES:
            test_operator(
                "fused_dequantize_take2",
                [
                    ((vocab_size, 256), "uint32"),
                    ((vocab_size, 64), "float16"),
                    ((seq_len,), "int32"),
                    ((seq_len, 2048), "float16"),
                ],
                golden_lib,
                buggy_lib,
            )

    # --- prefill ---
    # rms_norm1
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((2048,), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "rms_norm1",
            [
                ((1, seq_len, 2048), "float16"),
                ((2048,), "float16"),
                ((1, seq_len, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_dequantize1_fused_NT_matmul7_add2
    #   T.Buffer((6144, 256), 'uint32')
    #   T.Buffer((6144, 64), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((6144,), 'float16')
    #   T.Buffer((1, seq_len, 6144), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "fused_dequantize1_fused_NT_matmul7_add2",
            [
                ((6144, 256), "uint32"),
                ((6144, 64), "float16"),
                ((1, seq_len, 2048), "float16"),
                ((6144,), "float16"),
                ((1, seq_len, 6144), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape9
    #   T.Buffer((1, seq_len, 6144), 'float16')
    #   T.Buffer((1, seq_len, 48, 128), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "reshape9",
            [
                ((1, seq_len, 6144), "float16"),
                ((1, seq_len, 48, 128), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape10
    #   T.Buffer((1, seq_len, 48, 128), 'float16')
    #   T.Buffer((seq_len, 48, 128), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "reshape10",
            [
                ((1, seq_len, 48, 128), "float16"),
                ((seq_len, 48, 128), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape11
    #   T.Buffer((seq_len, 16, 128), 'float16')
    #   T.Buffer((1, seq_len, 16, 128), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "reshape11",
            [
                ((seq_len, 16, 128), "float16"),
                ((1, seq_len, 16, 128), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape12
    #   T.Buffer((1, seq_len, 16, 128), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "reshape12",
            [
                ((1, seq_len, 16, 128), "float16"),
                ((1, seq_len, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_dequantize2_NT_matmul8
    #   T.Buffer((2048, 256), 'uint32')
    #   T.Buffer((2048, 64), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "fused_dequantize2_NT_matmul8",
            [
                ((2048, 256), "uint32"),
                ((2048, 64), "float16"),
                ((1, seq_len, 2048), "float16"),
                ((1, seq_len, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fuse_add_norm_prefill
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((2048,), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "fuse_add_norm_prefill",
            [
                ((1, seq_len, 2048), "float16"),
                ((1, seq_len, 2048), "float16"),
                ((2048,), "float16"),
                ((1, seq_len, 2048), "float16"),
                ((1, seq_len, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape13
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "reshape13",
            [
                ((1, seq_len, 2048), "float16"),
                ((seq_len, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_NT_matmul2_cast
    #   T.Buffer((batch_size, 2048), 'float16')
    #   T.Buffer((60, 2048), 'float16')
    #   T.Buffer((batch_size, 60), 'float32')

    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_NT_matmul2_cast",
            [
                ((batch_size, 2048), "float16"),
                ((60, 2048), "float16"),
                ((batch_size, 60), "float32"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_softmax_cast1
    #   T.Buffer((batch_size, 60), 'float32')
    #   T.Buffer((batch_size, 60), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_softmax_cast1",
            [
                ((batch_size, 60), "float32"),
                ((batch_size, 60), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # top4_softmax
    #   T.Buffer((batch_size, 60), 'float16')
    #   T.Buffer((batch_size, 4), 'float16')
    #   T.Buffer((batch_size, 4), 'int32')
    for batch_size in BATCH_SIZE:
        test_operator(
            "top4_softmax",
            [
                ((batch_size, 60), "float16"),
                ((batch_size, 4), "float16"),
                ((batch_size, 4), "int32"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_expert_mask_transpose
    #   T.Buffer((batch_size, 4), 'int32')
    #   T.Buffer((60, batch_size), 'int32')
    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_expert_mask_transpose",
            [
                ((batch_size, 4), "int32"),
                ((60, batch_size), "int32"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape5
    #   T.Buffer((60, batch_size), 'int32')
    #   T.Buffer((batch_size * 60,), 'int32')
    for batch_size in BATCH_SIZE:
        test_operator(
            "reshape5",
            [
                ((60, batch_size), "int32"),
                ((batch_size * 60,), "int32"),
            ],
            golden_lib,
            buggy_lib,
        )

    # gpu_2d_continuous_cumsum2
    #   T.Buffer((m, n), 'int32')
    #   T.Buffer((m, n), 'int32')
    # NOTE(review): the signature above lists two buffers, but the call below
    # passes three (m, n) int32 buffers -- confirm against the kernel.
    for m in [1, 2, 4, 8, 16, 32]:
        for n in [64, 256, 512, 2352]:
            test_operator(
                "gpu_2d_continuous_cumsum2",
                [
                    ((m, n), "int32"),
                    ((m, n), "int32"),
                    ((m, n), "int32"),
                ],
                golden_lib,
                buggy_lib,
            )

    # get_indices
    #   T.Buffer((cumsum_len,), 'int32')
    #   T.Buffer((batch_size, 4), 'int32')
    #   T.Buffer((batch_size * 4,), 'int32')
    #   T.Buffer((batch_size * 4,), 'int32')

    for cumsum_len in [1, 2, 4, 8, 16]:
        for batch_size in BATCH_SIZE:
            test_operator(
                "get_indices",
                [
                    ((cumsum_len,), "int32"),
                    ((batch_size, 4), "int32"),
                    ((batch_size * 4,), "int32"),
                    ((batch_size * 4,), "int32"),
                ],
                golden_lib,
                buggy_lib,
            )

    # get_expert_instance_indptr
    #   T.Buffer((batch_size * 60,), 'int32')
    #   T.Buffer((61,), 'int32')
    #   T.int64
    #   skipped (no test_operator call is made for this operator)

    # take
    #   T.Buffer((batch_size, 2048), 'float16')
    #   T.Buffer((batch_size * 4,), 'int32')
    #   T.Buffer((batch_size * 4, 2048), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "take",
            [
                ((batch_size, 2048), "float16"),
                ((batch_size * 4,), "int32"),
                ((batch_size * 4, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # dequantize_group_gemm
    #   T.Buffer((batch_size, 2048), 'float16')
    #   T.Buffer((60, 2816, 256), 'uint32')
    #   T.Buffer((60, 2816, 64), 'float16')
    #   T.Buffer((61,), 'int32')
    #   T.Buffer((batch_size, 2816), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "dequantize_group_gemm",
            [
                ((batch_size, 2048), "float16"),
                ((60, 2816, 256), "uint32"),
                ((60, 2816, 64), "float16"),
                ((61,), "int32"),
                ((batch_size, 2816), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_split_silu_multiply (call currently commented out)
    #   T.Buffer((batch_size * 4, seq_len, 2816), 'float16')
    #   T.Buffer((batch_size * 4, seq_len, 1408), 'float16')
    # for batch_size in BATCH_SIZE:
    #     for seq_len in SEQ_CANDIDATES:
    #         test_operator(
    #             "fused_split_silu_multiply",
    #             [
    #                 ((batch_size * 4, seq_len, 2816), 'float16'),
    #                 ((batch_size * 4, seq_len, 1408), 'float16'),
    #             ],
    #             golden_lib,
    #             buggy_lib,
    #         )

    # dequantize_group_gemm1
    #   T.Buffer((batch_size, 1408), 'float16')
    #   T.Buffer((60, 2048, 176), 'uint32')
    #   T.Buffer((60, 2048, 44), 'float16')
    #   T.Buffer((61,), 'int32')
    #   T.Buffer((batch_size, 2048), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "dequantize_group_gemm1",
            [
                ((batch_size, 1408), "float16"),
                ((60, 2048, 176), "uint32"),
                ((60, 2048, 44), "float16"),
                ((61,), "int32"),
                ((batch_size, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # scatter_output
    #   T.Buffer((seq_len, 2048), 'float16')
    #   T.Buffer((seq_len,), 'int32')
    #   T.Buffer((seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "scatter_output",
            [
                ((seq_len, 2048), "float16"),
                ((seq_len,), "int32"),
                ((seq_len, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape6
    #   T.Buffer((batch_size, 4), 'float16')
    #   T.Buffer((batch_size, 4, 1), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "reshape6",
            [
                ((batch_size, 4), "float16"),
                ((batch_size, 4, 1), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape7
    #   T.Buffer((batch_size * 4, 2048), 'float16')
    #   T.Buffer((batch_size, 4, 2048), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "reshape7",
            [
                ((batch_size * 4, 2048), "float16"),
                ((batch_size, 4, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_multiply1_sum
    #   T.Buffer((batch_size, 4, 2048), 'float16')
    #   T.Buffer((batch_size, 4, 1), 'float16')
    #   T.Buffer((batch_size, 2048), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_multiply1_sum",
            [
                ((batch_size, 4, 2048), "float16"),
                ((batch_size, 4, 1), "float16"),
                ((batch_size, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_dequantize3_NT_matmul3
    #   T.Buffer((11264, 256), 'uint32')
    #   T.Buffer((11264, 64), 'float16')
    #   T.Buffer((batch_size, 2048), 'float16')
    #   T.Buffer((batch_size, 11264), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_dequantize3_NT_matmul3",
            [
                ((11264, 256), "uint32"),
                ((11264, 64), "float16"),
                ((batch_size, 2048), "float16"),
                ((batch_size, 11264), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_split1_silu1_multiply2
    #   T.Buffer((batch_size, 11264), 'float16')
    #   T.Buffer((batch_size, 5632), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_split1_silu1_multiply2",
            [
                ((batch_size, 11264), "float16"),
                ((batch_size, 5632), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_NT_matmul5_tir_sigmoid
    #   T.Buffer((batch_size, 2048), 'float16')
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((batch_size, 1), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_NT_matmul5_tir_sigmoid",
            [
                ((batch_size, 2048), "float16"),
                ((1, 2048), "float16"),
                ((batch_size, 1), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_dequantize4_fused_NT_matmul4_multiply3_add1
    #   T.Buffer((2048, 704), 'uint32')
    #   T.Buffer((2048, 176), 'float16')
    #   T.Buffer((batch_size, 5632), 'float16')
    #   T.Buffer((batch_size, 1), 'float16')
    #   T.Buffer((batch_size, 2048), 'float16')
    #   T.Buffer((batch_size, 2048), 'float16')
    for batch_size in BATCH_SIZE:
        test_operator(
            "fused_dequantize4_fused_NT_matmul4_multiply3_add1",
            [
                ((2048, 704), "uint32"),
                ((2048, 176), "float16"),
                ((batch_size, 5632), "float16"),
                ((batch_size, 1), "float16"),
                ((batch_size, 2048), "float16"),
                ((batch_size, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # reshape14
    #   T.Buffer((seq_len, 2048), 'float16')
    #   T.Buffer((1, seq_len, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "reshape14",
            [
                ((seq_len, 2048), "float16"),
                ((1, seq_len, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # index
    #   T.Buffer((1, seq_len, 2048), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    for seq_len in SEQ_CANDIDATES:
        test_operator(
            "index",
            [
                ((1, seq_len, 2048), "float16"),
                ((1, 1, 2048), "float16"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_dequantize_fused_NT_matmul16_cast6
    #   T.Buffer((vocab_size, 256), 'uint32')
    #   T.Buffer((vocab_size, 64), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    #   T.Buffer((1, 1, vocab_size), 'float32')
    for vocab_size in VOCAB_SIZES:
        test_operator(
            "fused_dequantize_fused_NT_matmul16_cast6",
            [
                ((vocab_size, 256), "uint32"),
                ((vocab_size, 64), "float16"),
                ((1, 1, 2048), "float16"),
                ((1, 1, vocab_size), "float32"),
            ],
            golden_lib,
            buggy_lib,
        )

    # --- decode ---
    # rms_norm2
    #   T.Buffer((1, 1, 2048), 'float16')
    #   T.Buffer((2048,), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    test_operator(
        "rms_norm2",
        [
            ((1, 1, 2048), "float16"),
            ((2048,), "float16"),
            ((1, 1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_dequantize1_fused_NT_matmul10_add3
    #   T.Buffer((6144, 256), 'uint32')
    #   T.Buffer((6144, 64), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    #   T.Buffer((6144,), 'float16')
    #   T.Buffer((1, 1, 6144), 'float16')
    test_operator(
        "fused_dequantize1_fused_NT_matmul10_add3",
        [
            ((6144, 256), "uint32"),
            ((6144, 64), "float16"),
            ((1, 1, 2048), "float16"),
            ((6144,), "float16"),
            ((1, 1, 6144), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_reshape15_reshape16
    #   T.Buffer((1, 1, 6144), 'float16')
    #   T.Buffer((1, 48, 128), 'float16')
    test_operator(
        "fused_reshape15_reshape16",
        [
            ((1, 1, 6144), "float16"),
            ((1, 48, 128), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_reshape17_reshape18
    #   T.Buffer((1, 16, 128), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    test_operator(
        "fused_reshape17_reshape18",
        [
            ((1, 16, 128), "float16"),
            ((1, 1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_dequantize2_NT_matmul11
    #   T.Buffer((2048, 256), 'uint32')
    #   T.Buffer((2048, 64), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    test_operator(
        "fused_dequantize2_NT_matmul11",
        [
            ((2048, 256), "uint32"),
            ((2048, 64), "float16"),
            ((1, 1, 2048), "float16"),
            ((1, 1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fuse_add_norm_prefill (decode)
    # Same operator name as in the prefill section, here with seq_len fixed
    # to 1; the outcome is added to the same OP_RESULTS entry.
    test_operator(
        "fuse_add_norm_prefill",
        [
            ((1, 1, 2048), "float16"),
            ((1, 1, 2048), "float16"),
            ((2048,), "float16"),
            ((1, 1, 2048), "float16"),
            ((1, 1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_reshape19
    #   T.Buffer((1, 1, 2048), 'float16')
    #   T.Buffer((1, 2048), 'float16')
    test_operator(
        "fused_reshape19",
        [
            ((1, 1, 2048), "float16"),
            ((1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_NT_matmul12_cast4
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((60, 2048), 'float16')
    #   T.Buffer((1, 60), 'float32')
    test_operator(
        "fused_NT_matmul12_cast4",
        [
            ((1, 2048), "float16"),
            ((60, 2048), "float16"),
            ((1, 60), "float32"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_softmax1_cast5
    #   T.Buffer((1, 60), 'float32')
    #   T.Buffer((1, 60), 'float16')
    test_operator(
        "fused_softmax1_cast5",
        [
            ((1, 60), "float32"),
            ((1, 60), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # top4_softmax
    #   T.Buffer((batch_size, 60), 'float16')
    #   T.Buffer((batch_size, 4), 'float16')
    #   T.Buffer((batch_size, 4), 'int32')
    # NOTE(review): same operator and shapes as the top4_softmax sweep in the
    # prefill section, so each batch size is exercised twice -- confirm this
    # repetition is intended.
    for batch_size in BATCH_SIZE:
        test_operator(
            "top4_softmax",
            [
                ((batch_size, 60), "float16"),
                ((batch_size, 4), "float16"),
                ((batch_size, 4), "int32"),
            ],
            golden_lib,
            buggy_lib,
        )

    # fused_reshape20
    #   T.Buffer((1, 4), 'float16')
    #   T.Buffer((1, 4, 1), 'float16')
    test_operator(
        "fused_reshape20",
        [
            ((1, 4), "float16"),
            ((1, 4, 1), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # moe_dequantize_gemv (call currently commented out)
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((60, 2816, 256), 'uint32')
    #   T.Buffer((60, 2816, 64), 'float16')
    #   T.Buffer((1, 4), 'int32')
    #   T.Buffer((4, 2816), 'float16')
    # test_operator(
    #     "moe_dequantize_gemv",
    #     [
    #         ((1, 2048), "float16"),
    #         ((60, 2816, 256), "uint32"),
    #         ((60, 2816, 64), "float16"),
    #         ((1, 4), "int32"),
    #         ((4, 2816), "float16"),
    #     ],
    #     golden_lib,
    #     buggy_lib,
    # )

    # fused_split2_silu2_multiply4
    #   T.Buffer((4, 2816), 'float16')
    #   T.Buffer((4, 1408), 'float16')
    test_operator(
        "fused_split2_silu2_multiply4",
        [
            ((4, 2816), "float16"),
            ((4, 1408), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # moe_dequantize_gemv1 (call currently commented out)
    #   T.Buffer((4, 1408), 'float16')
    #   T.Buffer((60, 2048, 176), 'uint32')
    #   T.Buffer((60, 2048, 44), 'float16')
    # NOTE(review): the signature above lists three buffers while the
    # commented-out call passes five -- confirm before re-enabling.
    # test_operator(
    #     "moe_dequantize_gemv1",
    #     [
    #         ((4, 1408), "float16"),
    #         ((60, 2048, 176), "uint32"),
    #         ((60, 2048, 44), "float16"),
    #         ((1, 4), "int32"),
    #         ((4, 2048), "float16"),
    #     ],
    #     golden_lib,
    #     buggy_lib,
    # )

    # reshape21
    #   T.Buffer((4, 2048), 'float16')
    #   T.Buffer((1, 4, 2048), 'float16')
    test_operator(
        "reshape21",
        [
            ((4, 2048), "float16"),
            ((1, 4, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_multiply5_sum1
    #   T.Buffer((1, 4, 2048), 'float16')
    #   T.Buffer((1, 4, 1), 'float16')
    #   T.Buffer((1, 2048), 'float16')
    test_operator(
        "fused_multiply5_sum1",
        [
            ((1, 4, 2048), "float16"),
            ((1, 4, 1), "float16"),
            ((1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_dequantize3_NT_matmul13
    #   T.Buffer((11264, 256), 'uint32')
    #   T.Buffer((11264, 64), 'float16')
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((1, 11264), 'float16')
    test_operator(
        "fused_dequantize3_NT_matmul13",
        [
            ((11264, 256), "uint32"),
            ((11264, 64), "float16"),
            ((1, 2048), "float16"),
            ((1, 11264), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_split3_silu3_multiply6
    #   T.Buffer((1, 11264), 'float16')
    #   T.Buffer((1, 5632), 'float16')
    test_operator(
        "fused_split3_silu3_multiply6",
        [
            ((1, 11264), "float16"),
            ((1, 5632), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_NT_matmul15_tir_sigmoid1
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((1, 1), 'float16')
    test_operator(
        "fused_NT_matmul15_tir_sigmoid1",
        [
            ((1, 2048), "float16"),
            ((1, 2048), "float16"),
            ((1, 1), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_dequantize4_fused_NT_matmul14_multiply7_add4
    #   T.Buffer((2048, 704), 'uint32')
    #   T.Buffer((2048, 176), 'float16')
    #   T.Buffer((1, 5632), 'float16')
    #   T.Buffer((1, 1), 'float16')
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((1, 2048), 'float16')
    test_operator(
        "fused_dequantize4_fused_NT_matmul14_multiply7_add4",
        [
            ((2048, 704), "uint32"),
            ((2048, 176), "float16"),
            ((1, 5632), "float16"),
            ((1, 1), "float16"),
            ((1, 2048), "float16"),
            ((1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # reshape22
    #   T.Buffer((1, 2048), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    test_operator(
        "reshape22",
        [
            ((1, 2048), "float16"),
            ((1, 1, 2048), "float16"),
        ],
        golden_lib,
        buggy_lib,
    )

    # fused_dequantize_fused_NT_matmul16_cast6
    #   T.Buffer((151936, 256), 'uint32')
    #   T.Buffer((151936, 64), 'float16')
    #   T.Buffer((1, 1, 2048), 'float16')
    #   T.Buffer((1, 1, 151936), 'float32')
    # NOTE(review): same operator as the VOCAB_SIZES sweep at the end of the
    # prefill section, here with the vocabulary size hard-coded -- confirm
    # the second run is intended.
    test_operator(
        "fused_dequantize_fused_NT_matmul16_cast6",
        [
            ((151936, 256), "uint32"),
            ((151936, 64), "float16"),
            ((1, 1, 2048), "float16"),
            ((1, 1, 151936), "float32"),
        ],
        golden_lib,
        buggy_lib,
    )


    # Everything from here down to the final report is commented out, so none
    # of the attention kernels or fused_rope are exercised.
    # # --- attn_kernels ---
    # # batch_decode_paged_kv
    # #   T.int32
    # #   T.Buffer((B, 32, 128), 'float16')
    # #   T.Buffer((max_num_pages, 2, 8, 16, 128), 'float16')
    # #   T.Buffer((B + 1,), 'int32')
    # #   T.Buffer((nnz_pages,), 'int32')
    # #   T.Buffer((B,), 'int32')
    # #   T.Buffer((B,), 'int32')
    # #   T.Buffer((B,), 'int32')
    # #   T.Buffer((B, 32, 64), 'float16')
    # #   T.Buffer((B, 32))
    # #   T.int32
    # #   T.float32
    # #   T.float32
    # #   T.float32
    # test_operator(
    #     "batch_decode_paged_kv",
    #     [
    #         ((), "int32"),
    #         ((2, 32, 64), "float16"),
    #         ((4, 2, 8, 16, 64), "float16"),
    #         ((3,), "int32"),
    #         ((10,), "int32"),
    #         ((2,), "int32"),
    #         ((2,), "int32"),
    #         ((2,), "int32"),
    #         ((2, 32, 64), "float16"),
    #         ((2, 32), "float16"),
    #         ((), "int32"),
    #         ((), "float32"),
    #         ((), "float32"),
    #         ((), "float32"),
    #     ],
    #     golden_lib,
    #     buggy_lib,
    # )

    # # batch_prefill_paged_kv
    # #   T.int32
    # #   T.Buffer((total_len, 32, 64), 'float16')
    # #   T.Buffer((batch_size+1,), 'int32')
    # #   T.Buffer((max_num_pages,2,8,16,64), 'float16')
    # #   T.Buffer((batch_size+1,), 'int32')
    # #   T.Buffer((nnz_pages,), 'int32')
    # #   T.Buffer((batch_size,), 'int32')
    # #   T.Buffer((batch_size,), 'int32')
    # #   T.Buffer((total_len,), 'int32')
    # #   T.Buffer((total_len,32,64), 'float16')
    # #   T.Buffer((total_len,32))
    # #   T.int32
    # #   T.int32
    # #   T.float32
    # #   T.float32
    # #   T.float32
    # test_operator(
    #     "batch_prefill_paged_kv",
    #     [
    #         ((), "int32"),
    #         ((10, 32, 64), "float16"),
    #         ((3,), "int32"),
    #         ((4, 2, 8, 16, 64), "float16"),
    #         ((3,), "int32"),
    #         ((10,), "int32"),
    #         ((2,), "int32"),
    #         ((2,), "int32"),
    #         ((10,), "int32"),
    #         ((10, 32, 64), "float16"),
    #         ((10, 32), "float32"),
    #         ((), "int32"),
    #         ((), "int32"),
    #         ((), "float32"),
    #         ((), "float32"),
    #         ((), "float32"),
    #     ],
    #     golden_lib,
    #     buggy_lib,
    # )

    # # batch_prefill_ragged_kv
    # #   T.Buffer((qo_len, 32, 64), 'float16')
    # #   T.Buffer((batch_size+1,), 'int32')
    # #   T.Buffer((kv_len, 8, 64), 'float16')
    # #   T.Buffer((kv_len, 8, 64), 'float16')
    # #   T.Buffer((batch_size+1,), 'int32')
    # #   T.Buffer((qo_len,), 'int32')
    # #   T.Buffer((batch_size,), 'int32')
    # #   T.Buffer((qo_len, 32, 64), 'float16')
    # #   T.Buffer((qo_len, 32))
    # #   T.int32
    # #   T.int32
    # #   T.float32
    # #   T.float32
    # #   T.float32
    # test_operator(
    #     "batch_prefill_ragged_kv",
    #     [
    #         ((5, 32, 64), "float16"),
    #         ((3,), "int32"),
    #         ((7, 8, 64), "float16"),
    #         ((7, 8, 64), "float16"),
    #         ((3,), "int32"),
    #         ((5,), "int32"),
    #         ((2,), "int32"),
    #         ((5, 32, 64), "float16"),
    #         ((5, 32), "float32"),
    #         ((), "int32"),
    #         ((), "int32"),
    #         ((), "float32"),
    #         ((), "float32"),
    #         ((), "float32"),
    #     ],
    #     golden_lib,
    #     buggy_lib,
    # )

    # fused_rope
    #   T.Buffer((seq_len, 20, 128), 'float16')
    #   T.Buffer((seq_len,), 'int32')
    #   T.Buffer((seq_len, 16, 128), 'float16')
    #   T.Buffer((seq_len, 2, 128), 'float16')
    #   T.Buffer((seq_len, 2, 128), 'float16')
    #   T.int32
    # for seq_len in SEQ_CANDIDATES:
    #     test_operator(
    #         "fused_rope",
    #         [
    #             ((seq_len, 20, 128), "float16"),
    #             ((seq_len,), "int32"),
    #             ((seq_len, 16, 128), "float16"),
    #             ((seq_len, 2, 128), "float16"),
    #             ((seq_len, 2, 128), "float16"),
    #             ((), "int32"),
    #         ],
    #         golden_lib,
    #         buggy_lib,
    #     )

    # Final report: list only the operators with at least one failing run.
    print("Detected buggy operators: ")
    for op, (corrects, fails) in OP_RESULTS.items():
        if fails != 0:
            print(f"Operator {op} passes {corrects} tests but fails {fails} tests")


if __name__ == "__main__":
    main()
