import tvm
import numpy as np

# Reference build: device name handed to tvm.device() and the compiled
# library loaded for it.
GOLDEN_DEVICE = "cuda"
GOLDEN_LIB = "cases/6/cuda.so"
# Build under test: same operators compiled for another device.
BUGGY_DEVICE = "vulkan"
BUGGY_LIB = "cases/6/vulkan.so"
# Tolerances passed to np.allclose when comparing the two outputs. The
# MATMUL_* values are used instead for operators whose name contains "matmul".
ATOL = 1e-3
MATMUL_ATOL = 1
RTOL = 1e-2
MATMUL_RTOL = 1e-1
# Values substituted for the seq_len / vocab_size dimensions in main().
SEQ_CANDIDATES = [1, 2, 4, 8, 16, 32]
VOCAB_SIZES = [151936]

# Operator name -> (matching runs, mismatching runs); updated by test_operator().
OP_RESULTS = {}


def random_array(shape, dtype):
    """Create random test data of the given shape and dtype.

    Integer dtypes are drawn uniformly from ``[0, 100)`` with
    ``np.random.randint``; every other dtype is drawn from
    ``np.random.randn`` and cast to ``dtype``.

    Args:
        shape: Tuple of dimensions. An empty tuple requests a scalar.
        dtype: Anything ``np.dtype`` accepts: a name such as ``"float16"``,
            a NumPy scalar type such as ``np.int32``, or a ``np.dtype``.

    Returns:
        A ``np.ndarray`` of the requested shape and dtype, or a plain Python
        scalar (via ``.item()``) when ``shape`` is empty.
    """
    np_dtype = np.dtype(dtype)
    # Classify through NumPy's type hierarchy rather than by substring so
    # dtype objects and short codes such as "i4" are handled as well.
    is_integer = np.issubdtype(np_dtype, np.integer)

    if len(shape) == 0:
        sample = np.random.randint(0, 100) if is_integer else np.random.randn()
        # Round-trip through the target dtype so the scalar carries the same
        # rounding an array element of that dtype would have.
        return np.array(sample, dtype=np_dtype).item()

    if is_integer:
        return np.random.randint(0, 100, size=shape).astype(np_dtype)
    return np.random.randn(*shape).astype(np_dtype)


def test_operator(op, types, golden_lib, buggy_lib):
    """Run one operator from both libraries on identical inputs and compare.

    Random arguments are generated once with ``random_array`` and copied to
    both devices. After both functions have run, the last buffer argument is
    treated as the output and compared with ``np.allclose``. The outcome is
    printed and tallied in ``OP_RESULTS``.

    Args:
        op: Name of the function to look up in both libraries. Names
            containing "matmul" are compared with the MATMUL_* tolerances.
        types: Sequence of ``(shape, dtype)`` pairs, one per argument, in
            call order. An empty shape produces a scalar argument.
        golden_lib: Loaded module providing the reference implementation.
        buggy_lib: Loaded module providing the implementation under test.

    Raises:
        ValueError: If ``types`` describes no buffer argument, so there is
            nothing to compare.
    """
    inputs = [random_array(shape, dtype) for shape, dtype in types]

    # Only the last buffer is compared. Trailing scalar arguments must be
    # skipped: they are the same Python value on both sides, so comparing
    # them would report a match no matter what the kernels produced.
    buffer_positions = [
        idx for idx, arr in enumerate(inputs) if isinstance(arr, np.ndarray)
    ]
    if not buffer_positions:
        raise ValueError(
            f"Operator {op} has no buffer arguments, so there is no output to compare"
        )
    out_idx = buffer_positions[-1]

    golden_dev = tvm.device(GOLDEN_DEVICE)
    buggy_dev = tvm.device(BUGGY_DEVICE)

    # Arrays are copied to the target device; scalars are passed through.
    golden_inputs = [
        tvm.nd.array(arr, device=golden_dev) if isinstance(arr, np.ndarray) else arr
        for arr in inputs
    ]
    buggy_inputs = [
        tvm.nd.array(arr, device=buggy_dev) if isinstance(arr, np.ndarray) else arr
        for arr in inputs
    ]

    golden_func = golden_lib.get_function(op, query_imports=True)
    buggy_func = buggy_lib.get_function(op, query_imports=True)

    golden_func(*golden_inputs)
    buggy_func(*buggy_inputs)

    golden_np = golden_inputs[out_idx].numpy()
    buggy_np = buggy_inputs[out_idx].numpy()

    is_matmul = "matmul" in op
    atol = MATMUL_ATOL if is_matmul else ATOL
    rtol = MATMUL_RTOL if is_matmul else RTOL

    passed, failed = OP_RESULTS.get(op, (0, 0))
    if np.allclose(golden_np, buggy_np, atol=atol, rtol=rtol):
        print(f"✅ Operator {op} matches within tolerance (atol={atol}, rtol={rtol})")
        OP_RESULTS[op] = (passed + 1, failed)
    else:
        # Subtract in float64 so the reported difference cannot overflow in
        # float16 or wrap around for unsigned integer buffers.
        diff = np.max(
            np.abs(golden_np.astype(np.float64) - buggy_np.astype(np.float64))
        )
        print(f"❌ Operator {op} mismatch! Max diff: {diff}")
        OP_RESULTS[op] = (passed, failed + 1)


def main():
    """Load both compiled libraries and compare every listed operator.

    Each entry below gives an operator name and the ``(shape, dtype)`` list
    of its arguments in call order. Operators with a sequence dimension are
    run once per value in ``SEQ_CANDIDATES``; operators with a vocabulary
    dimension once per value in ``VOCAB_SIZES``. Per-operator results are
    printed by ``test_operator``; a summary of the operators with at least
    one mismatch is printed at the end.
    """
    golden_lib = tvm.runtime.load_module(GOLDEN_LIB)
    buggy_lib = tvm.runtime.load_module(BUGGY_LIB)

    f16, f32, i32, u32 = "float16", "float32", "int32", "uint32"

    def run(op, types):
        test_operator(op, types, golden_lib, buggy_lib)

    def matmul14_types(vocab_size):
        # Argument list of fused_dequantize_NT_matmul14, which is run in both
        # the prefill and the decode section.
        return [
            ((vocab_size, 256), u32),
            ((vocab_size, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, vocab_size), f32),
        ]

    # --- embed ---
    for seq_len in SEQ_CANDIDATES:
        for vocab_size in VOCAB_SIZES:
            run(
                "fused_dequantize_take1",
                [
                    ((vocab_size, 256), u32),
                    ((vocab_size, 64), f16),
                    ((seq_len,), i32),
                    ((seq_len, 2048), f16),
                ],
            )

    # --- prefill ---
    # Each builder receives the sequence length `s` and returns the argument
    # list. Every operator is run for all sequence lengths before moving on
    # to the next operator.
    prefill_cases = [
        ("rms_norm1", lambda s: [
            ((1, s, 2048), f16),
            ((2048,), f16),
            ((1, s, 2048), f16),
        ]),
        ("fused_dequantize1_fused_NT_matmul5_add1", lambda s: [
            ((2560, 256), u32),
            ((2560, 64), f16),
            ((1, s, 2048), f16),
            ((2560,), f16),
            ((1, s, 2560), f16),
        ]),
        ("reshape4", lambda s: [
            ((1, s, 2560), f16),
            ((1, s, 20, 128), f16),
        ]),
        ("reshape5", lambda s: [
            ((1, s, 20, 128), f16),
            ((s, 20, 128), f16),
        ]),
        ("reshape6", lambda s: [
            ((s, 16, 128), f16),
            ((1, s, 16, 128), f16),
        ]),
        ("reshape7", lambda s: [
            ((1, s, 16, 128), f16),
            ((1, s, 2048), f16),
        ]),
        ("fused_dequantize2_NT_matmul6", lambda s: [
            ((2048, 256), u32),
            ((2048, 64), f16),
            ((1, s, 2048), f16),
            ((1, s, 2048), f16),
        ]),
        ("fuse_add_norm_prefill", lambda s: [
            ((1, s, 2048), f16),
            ((1, s, 2048), f16),
            ((2048,), f16),
            ((1, s, 2048), f16),
            ((1, s, 2048), f16),
        ]),
        ("fused_dequantize3_NT_matmul7", lambda s: [
            ((22016, 256), u32),
            ((22016, 64), f16),
            ((1, s, 2048), f16),
            ((1, s, 22016), f16),
        ]),
        ("fused_split1_silu1_multiply1", lambda s: [
            ((1, s, 22016), f16),
            ((1, s, 11008), f16),
        ]),
        ("fused_dequantize4_NT_matmul8", lambda s: [
            ((2048, 1376), u32),
            ((2048, 344), f16),
            ((1, s, 11008), f16),
            ((1, s, 2048), f16),
        ]),
        ("index", lambda s: [
            ((1, s, 2048), f16),
            ((1, 1, 2048), f16),
        ]),
    ]
    for op, make_types in prefill_cases:
        for seq_len in SEQ_CANDIDATES:
            run(op, make_types(seq_len))

    for vocab_size in VOCAB_SIZES:
        run("fused_dequantize_NT_matmul14", matmul14_types(vocab_size))

    # --- decode ---
    # Fixed shapes with a sequence dimension of 1; each operator runs once.
    decode_cases = [
        ("rms_norm2", [
            ((1, 1, 2048), f16),
            ((2048,), f16),
            ((1, 1, 2048), f16),
        ]),
        ("fused_dequantize1_fused_NT_matmul10_add2", [
            ((2560, 256), u32),
            ((2560, 64), f16),
            ((1, 1, 2048), f16),
            ((2560,), f16),
            ((1, 1, 2560), f16),
        ]),
        ("fused_reshape8_reshape9", [
            ((1, 1, 2560), f16),
            ((1, 20, 128), f16),
        ]),
        ("fused_reshape10_reshape11", [
            ((1, 16, 128), f16),
            ((1, 1, 2048), f16),
        ]),
        ("fused_dequantize2_NT_matmul11", [
            ((2048, 256), u32),
            ((2048, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, 2048), f16),
        ]),
        # The prefill-named kernel is run again here with seq_len fixed to 1.
        ("fuse_add_norm_prefill", [
            ((1, 1, 2048), f16),
            ((1, 1, 2048), f16),
            ((2048,), f16),
            ((1, 1, 2048), f16),
            ((1, 1, 2048), f16),
        ]),
        ("fused_dequantize3_NT_matmul12", [
            ((22016, 256), u32),
            ((22016, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, 22016), f16),
        ]),
        ("fused_split2_silu2_multiply2", [
            ((1, 1, 22016), f16),
            ((1, 1, 11008), f16),
        ]),
        ("fused_dequantize4_NT_matmul13", [
            ((2048, 1376), u32),
            ((2048, 344), f16),
            ((1, 1, 11008), f16),
            ((1, 1, 2048), f16),
        ]),
    ]
    for op, types in decode_cases:
        run(op, types)

    for vocab_size in VOCAB_SIZES:
        run("fused_dequantize_NT_matmul14", matmul14_types(vocab_size))

    # --- attention kernels (disabled) ---
    # TODO: these three kernels are not exercised. Draft argument lists, in
    # call order (`()` is a scalar argument):
    #   batch_decode_paged_kv:
    #     () i32, (2, 32, 64) f16, (4, 2, 8, 16, 64) f16, (3,) i32,
    #     (10,) i32, (2,) i32, (2,) i32, (2,) i32, (2, 32, 64) f16,
    #     (2, 32) f16, () i32, () f32, () f32, () f32
    #   batch_prefill_paged_kv:
    #     () i32, (10, 32, 64) f16, (3,) i32, (4, 2, 8, 16, 64) f16,
    #     (3,) i32, (10,) i32, (2,) i32, (2,) i32, (10,) i32,
    #     (10, 32, 64) f16, (10, 32) f32, () i32, () i32,
    #     () f32, () f32, () f32
    #   batch_prefill_ragged_kv:
    #     (5, 32, 64) f16, (3,) i32, (7, 8, 64) f16, (7, 8, 64) f16,
    #     (3,) i32, (5,) i32, (2,) i32, (5, 32, 64) f16, (5, 32) f32,
    #     () i32, () i32, () f32, () f32, () f32
    # NOTE(review): presumably disabled because their index buffers need
    # structured rather than random contents -- confirm before re-enabling.

    # --- rope ---
    for seq_len in SEQ_CANDIDATES:
        run(
            "fused_rope",
            [
                ((seq_len, 20, 128), f16),
                ((seq_len,), i32),
                ((seq_len, 16, 128), f16),
                ((seq_len, 2, 128), f16),
                ((seq_len, 2, 128), f16),
                ((), i32),
            ],
        )

    print("Detected buggy operators: ")
    failing = [
        (op, corrects, fails)
        for op, (corrects, fails) in OP_RESULTS.items()
        if fails != 0
    ]
    for op, corrects, fails in failing:
        print(f"Operator {op} passes {corrects} tests but fails {fails} tests")
    if not failing:
        # Make a clean run explicit instead of leaving a bare header.
        print("(none)")


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
