import tvm
import numpy as np
import os

# Directory holding this script; both compiled libraries are resolved relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Reference ("golden") backend and its compiled library.
GOLDEN_DEVICE, GOLDEN_LIB = "cuda", os.path.join(current_dir, "cuda.so")
# Backend under test and its compiled library.
BUGGY_DEVICE, BUGGY_LIB = "vulkan", os.path.join(current_dir, "vulkan.so")

# Default absolute / relative tolerances for comparing operator outputs.
ATOL, RTOL = 1e-1, 1e-3

# Names of operators whose outputs disagreed between the two backends.
BUGGY_OPS = []


def random_array(shape, dtype):
    """Return random host data for one kernel argument.

    Integer dtypes (signed or unsigned) get uniform values in ``[0, 100)``;
    every other dtype gets standard-normal samples cast to ``dtype``. An empty
    shape yields a Python scalar (via ``.item()``) instead of an array.

    Args:
        shape: Tuple of dimensions; ``()`` requests a scalar.
        dtype: Anything ``np.dtype`` accepts, e.g. ``"float16"``, ``"i4"``
            or ``np.dtype("uint32")``.

    Returns:
        An ``np.ndarray`` with the requested shape and dtype, or a Python
        scalar when ``shape`` is empty.
    """
    # Classify by dtype kind ("i" signed, "u" unsigned) rather than by
    # substring, so aliases like "i4"/"u1" and np.dtype objects are treated
    # as integers too.
    if np.dtype(dtype).kind in "iu":
        if len(shape) == 0:
            return np.array(np.random.randint(0, 100), dtype=dtype).item()
        return np.random.randint(0, 100, size=shape).astype(dtype)

    if len(shape) == 0:
        return np.array(np.random.randn(*shape), dtype=dtype).item()

    return np.random.randn(*shape).astype(dtype)


def _to_device_args(host_args, device):
    """Copy each ``np.ndarray`` in ``host_args`` to ``device``; pass other values through."""
    return [
        tvm.nd.array(arg, device=device) if isinstance(arg, np.ndarray) else arg
        for arg in host_args
    ]


def _to_numpy(value):
    """Return ``value`` on the host: TVM NDArrays are copied back, scalars wrapped."""
    if isinstance(value, tvm.nd.NDArray):
        return value.numpy()
    return np.asarray(value)


def _resolve_output_indices(op, host_args, output_indices):
    """Return the argument positions to compare, checking each one is a buffer."""
    buffer_positions = [
        i for i, arg in enumerate(host_args) if isinstance(arg, np.ndarray)
    ]
    if output_indices is None:
        # Default to the last *buffer* argument. Trailing scalar arguments are
        # passed by value, so comparing one would always trivially "match".
        if not buffer_positions:
            raise ValueError(f"Operator {op} has no buffer argument to compare")
        return buffer_positions[-1:]

    indices = list(output_indices)
    for idx in indices:
        if not isinstance(host_args[idx], np.ndarray):
            raise ValueError(
                f"Argument {idx} of operator {op} is a scalar, not a buffer"
            )
    return indices


def test_operator(op, types, golden_lib, buggy_lib, output_indices=None):
    """Run ``op`` on the golden and buggy backends and compare the results.

    Every argument is generated once with ``random_array`` and the same host
    values are copied to both devices, so any difference comes from the
    kernels. Results are read back from the argument buffers after the calls.

    Args:
        op: Name of the function to look up in both libraries.
        types: Sequence of ``(shape, dtype)`` pairs, one per kernel argument,
            in call order. An empty shape produces a scalar argument.
        golden_lib: Loaded module for ``GOLDEN_DEVICE`` (reference results).
        buggy_lib: Loaded module for ``BUGGY_DEVICE`` (results under test).
        output_indices: Optional positions in ``types`` of the buffers to
            compare. Defaults to the last buffer argument, skipping any
            trailing scalars.

    Returns:
        True if every compared buffer matches within tolerance, else False.
        On a mismatch, ``op`` is appended to ``BUGGY_OPS`` (at most once).

    Raises:
        ValueError: If there is no buffer to compare, or an explicitly
            requested index refers to a scalar argument.
    """
    host_args = [random_array(shape, dtype) for shape, dtype in types]
    compare_at = _resolve_output_indices(op, host_args, output_indices)

    golden_args = _to_device_args(host_args, tvm.device(GOLDEN_DEVICE))
    buggy_args = _to_device_args(host_args, tvm.device(BUGGY_DEVICE))

    golden_func = golden_lib.get_function(op, query_imports=True)
    buggy_func = buggy_lib.get_function(op, query_imports=True)

    # The kernels are expected to write their results in place into the
    # argument buffer(s) that are read back below.
    golden_func(*golden_args)
    buggy_func(*buggy_args)

    if "matmul" in op:
        # Looser bound than the global defaults, presumably because the long
        # reductions in matmul kernels accumulate more rounding error.
        atol, rtol = 5e-1, 1e-2
    else:
        atol, rtol = ATOL, RTOL

    mismatch_diffs = []
    for idx in compare_at:
        # Upcast to float64 so neither the tolerance check nor the difference
        # overflows float16 or wraps around for unsigned integer outputs.
        golden_np = _to_numpy(golden_args[idx]).astype(np.float64)
        buggy_np = _to_numpy(buggy_args[idx]).astype(np.float64)
        if not np.allclose(golden_np, buggy_np, atol=atol, rtol=rtol):
            mismatch_diffs.append(np.max(np.abs(golden_np - buggy_np)))

    if not mismatch_diffs:
        print(f"✅ Operator {op} matches within tolerance (atol={atol}, rtol={rtol})")
        return True

    # np.max propagates NaN, so non-finite outputs can show up as a NaN diff.
    diff = np.max(mismatch_diffs)
    print(f"❌ Operator {op} mismatch! Max diff: {diff}")
    # The same op may be exercised at several shapes; record it only once.
    if op not in BUGGY_OPS:
        BUGGY_OPS.append(op)
    return False


def main():
    """Run every kernel case on both backends and report the mismatching ones.

    Each entry of ``cases`` is ``(op_name, [(shape, dtype), ...])``, with one
    ``(shape, dtype)`` pair per kernel argument in call order. ``test_operator``
    generates the data, runs both libraries and compares the results. The
    de-duplicated names of mismatching operators are printed and written,
    one per line, to ``bottom-up.txt`` in the script's directory.
    """
    golden_lib = tvm.runtime.load_module(GOLDEN_LIB)
    buggy_lib = tvm.runtime.load_module(BUGGY_LIB)

    # Concrete values for the symbolic dimensions of the kernel signatures.
    seq_len = 128
    vocab_size = 10000
    f16, f32, i32, u32 = "float16", "float32", "int32", "uint32"

    cases = [
        # --- embed ---
        ("fused_dequantize_take1", [
            ((vocab_size, 256), u32),
            ((vocab_size, 64), f16),
            ((seq_len,), i32),
            ((seq_len, 2048), f16),
        ]),
        # --- prefill ---
        ("rms_norm1", [
            ((1, seq_len, 2048), f16),
            ((2048,), f16),
            ((1, seq_len, 2048), f16),
        ]),
        ("fused_dequantize1_NT_matmul5", [
            ((3072, 256), u32),
            ((3072, 64), f16),
            ((1, seq_len, 2048), f16),
            ((1, seq_len, 3072), f16),
        ]),
        ("reshape4", [
            ((1, seq_len, 3072), f16),
            ((1, seq_len, 48, 64), f16),
        ]),
        ("reshape5", [
            ((1, seq_len, 48, 64), f16),
            ((seq_len, 48, 64), f16),
        ]),
        ("reshape6", [
            ((seq_len, 32, 64), f16),
            ((1, seq_len, 32, 64), f16),
        ]),
        ("reshape7", [
            ((1, seq_len, 32, 64), f16),
            ((1, seq_len, 2048), f16),
        ]),
        ("fused_dequantize2_NT_matmul6", [
            ((2048, 256), u32),
            ((2048, 64), f16),
            ((1, seq_len, 2048), f16),
            ((1, seq_len, 2048), f16),
        ]),
        ("fuse_add_norm_prefill", [
            ((1, seq_len, 2048), f16),
            ((1, seq_len, 2048), f16),
            ((2048,), f16),
            ((1, seq_len, 2048), f16),
            ((1, seq_len, 2048), f16),
        ]),
        ("fused_dequantize3_NT_matmul7", [
            ((16384, 256), u32),
            ((16384, 64), f16),
            ((1, seq_len, 2048), f16),
            ((1, seq_len, 16384), f16),
        ]),
        ("fused_split1_silu1_multiply1", [
            ((1, seq_len, 16384), f16),
            ((1, seq_len, 8192), f16),
        ]),
        ("fused_dequantize4_NT_matmul8", [
            ((2048, 1024), u32),
            ((2048, 256), f16),
            ((1, seq_len, 8192), f16),
            ((1, seq_len, 2048), f16),
        ]),
        ("index", [
            ((1, seq_len, 2048), f16),
            ((1, 1, 2048), f16),
        ]),
        # NOTE(review): this kernel's signature was previously documented with
        # a float16 (1, 1, vocab_size) last buffer, but float32 is passed here
        # and in the decode entry below; confirm against the compiled kernel.
        ("fused_dequantize_NT_matmul14", [
            ((vocab_size, 256), u32),
            ((vocab_size, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, vocab_size), f32),
        ]),
        # --- decode ---
        ("rms_norm2", [
            ((1, 1, 2048), f16),
            ((2048,), f16),
            ((1, 1, 2048), f16),
        ]),
        ("fused_dequantize1_NT_matmul10", [
            ((3072, 256), u32),
            ((3072, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, 3072), f16),
        ]),
        ("fused_reshape8_reshape9", [
            ((1, 1, 3072), f16),
            ((1, 48, 64), f16),
        ]),
        ("fused_reshape10_reshape11", [
            ((1, 32, 64), f16),
            ((1, 1, 2048), f16),
        ]),
        ("fused_dequantize2_NT_matmul11", [
            ((2048, 256), u32),
            ((2048, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, 2048), f16),
        ]),
        # Same kernel as in prefill, exercised with a sequence dimension of 1.
        ("fuse_add_norm_prefill", [
            ((1, 1, 2048), f16),
            ((1, 1, 2048), f16),
            ((2048,), f16),
            ((1, 1, 2048), f16),
            ((1, 1, 2048), f16),
        ]),
        ("fused_dequantize3_NT_matmul12", [
            ((16384, 256), u32),
            ((16384, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, 16384), f16),
        ]),
        ("fused_split2_silu2_multiply2", [
            ((1, 1, 16384), f16),
            ((1, 1, 8192), f16),
        ]),
        ("fused_dequantize4_NT_matmul13", [
            ((2048, 1024), u32),
            ((2048, 256), f16),
            ((1, 1, 8192), f16),
            ((1, 1, 2048), f16),
        ]),
        # Same arguments as the prefill entry; re-run with fresh random data.
        ("fused_dequantize_NT_matmul14", [
            ((vocab_size, 256), u32),
            ((vocab_size, 64), f16),
            ((1, 1, 2048), f16),
            ((1, 1, vocab_size), f32),
        ]),
        # --- attn_kernels ---
        # The paged/ragged attention kernels below are disabled. NOTE(review):
        # presumably because random int32 metadata (indptr / page indices /
        # lengths) does not form a valid KV layout; confirm before enabling.
        #
        # batch_decode_paged_kv signature: int32, (B,32,64) f16,
        #   (max_num_pages,2,8,16,64) f16, (B+1,) i32, (nnz_pages,) i32,
        #   (B,) i32 x3, (B,32,64) f16, (B,32), int32, float32 x3
        # ("batch_decode_paged_kv", [
        #     ((), i32), ((2, 32, 64), f16), ((4, 2, 8, 16, 64), f16),
        #     ((3,), i32), ((10,), i32), ((2,), i32), ((2,), i32), ((2,), i32),
        #     ((2, 32, 64), f16), ((2, 32), f16),
        #     ((), i32), ((), f32), ((), f32), ((), f32),
        # ]),
        #
        # batch_prefill_paged_kv signature: int32, (total_len,32,64) f16,
        #   (batch_size+1,) i32, (max_num_pages,2,8,16,64) f16,
        #   (batch_size+1,) i32, (nnz_pages,) i32, (batch_size,) i32 x2,
        #   (total_len,) i32, (total_len,32,64) f16, (total_len,32),
        #   int32 x2, float32 x3
        # ("batch_prefill_paged_kv", [
        #     ((), i32), ((10, 32, 64), f16), ((3,), i32),
        #     ((4, 2, 8, 16, 64), f16), ((3,), i32), ((10,), i32),
        #     ((2,), i32), ((2,), i32), ((10,), i32),
        #     ((10, 32, 64), f16), ((10, 32), f32),
        #     ((), i32), ((), i32), ((), f32), ((), f32), ((), f32),
        # ]),
        #
        # batch_prefill_ragged_kv signature: (qo_len,32,64) f16,
        #   (batch_size+1,) i32, (kv_len,8,64) f16 x2, (batch_size+1,) i32,
        #   (qo_len,) i32, (batch_size,) i32, (qo_len,32,64) f16,
        #   (qo_len,32), int32 x2, float32 x3
        # ("batch_prefill_ragged_kv", [
        #     ((5, 32, 64), f16), ((3,), i32),
        #     ((7, 8, 64), f16), ((7, 8, 64), f16),
        #     ((3,), i32), ((5,), i32), ((2,), i32),
        #     ((5, 32, 64), f16), ((5, 32), f32),
        #     ((), i32), ((), i32), ((), f32), ((), f32), ((), f32),
        # ]),
        #
        # fused_rope signature: (seq_len,48,64) f16, (seq_len,) i32,
        #   (seq_len,32,64) f16, (seq_len,8,64) f16 x2, int32.
        # NOTE(review): the trailing argument is a scalar; make sure
        # test_operator compares one of the buffers here, and if args 2-4 are
        # all outputs consider checking each of them.
        ("fused_rope", [
            ((seq_len, 48, 64), f16),
            ((seq_len,), i32),
            ((seq_len, 32, 64), f16),
            ((seq_len, 8, 64), f16),
            ((seq_len, 8, 64), f16),
            ((), i32),
        ]),
    ]

    for op, types in cases:
        test_operator(op, types, golden_lib, buggy_lib)

    # Some ops are tested at several shapes; report each name once, in the
    # order it was first detected.
    buggy_ops = list(dict.fromkeys(BUGGY_OPS))
    print(f"Detected buggy operators: {buggy_ops}")
    # Reuse the module-level script directory instead of recomputing it.
    report_path = os.path.join(current_dir, "bottom-up.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(buggy_ops))


if __name__ == "__main__":
    # Run the backend comparison only when executed as a script, not on import.
    main()
