import tvm
import numpy as np
import os

# Directory containing this script; the compiled libraries are looked up here
# and the result file is written here.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Reference ("golden") build: device name and the module file to load.
GOLDEN_DEVICE, GOLDEN_LIB = "cuda", os.path.join(current_dir, "cuda.so")
# Build under test ("buggy"): device name and the module file to load.
BUGGY_DEVICE, BUGGY_LIB = "vulkan", os.path.join(current_dir, "vulkan.so")

# Default absolute / relative tolerances for comparing operator outputs.
ATOL, RTOL = 1e-1, 1e-3

# Names of operators whose outputs differed; appended to while tests run.
BUGGY_OPS = list()


def random_array(shape, dtype):
    """Create random test data with the requested shape and dtype.

    Integer dtypes (signed or unsigned) are drawn uniformly from ``[0, 100)``
    with ``np.random.randint``. Every other dtype is drawn from the standard
    normal distribution with ``np.random.randn`` and then cast to ``dtype``.

    Args:
        shape: Sequence of dimension sizes. An empty sequence requests a
            scalar.
        dtype: Anything ``np.dtype`` accepts: a name such as ``"float16"``,
            a ``np.dtype`` instance, or a NumPy scalar type such as
            ``np.int32``.

    Returns:
        A plain Python scalar (obtained with ``.item()``) when ``shape`` is
        empty, otherwise a ``np.ndarray`` of the given shape and dtype.

    Raises:
        TypeError: If ``dtype`` is not something ``np.dtype`` understands.
    """
    # Normalise once so the integer test works on the real type instead of
    # on a substring of its name; this also admits np.dtype objects and
    # NumPy scalar types, which do not support the ``in`` substring test.
    np_dtype = np.dtype(dtype)
    wants_scalar = len(shape) == 0

    if np.issubdtype(np_dtype, np.integer):
        if wants_scalar:
            return np.array(np.random.randint(0, 100), dtype=np_dtype).item()
        return np.random.randint(0, 100, size=shape).astype(np_dtype)

    if wants_scalar:
        return np.array(np.random.randn(), dtype=np_dtype).item()
    return np.random.randn(*shape).astype(np_dtype)


def _place_on_device(values, device):
    """Copy each NumPy array in *values* to *device*; pass other values through."""
    return [
        tvm.nd.array(value, device=device) if isinstance(value, np.ndarray) else value
        for value in values
    ]


def _as_numpy(value):
    """Return *value* as host data: TVM NDArrays are copied back, others kept."""
    return value.numpy() if isinstance(value, tvm.nd.NDArray) else value


def test_operator(op, types, golden_lib, buggy_lib):
    """Run one operator on both builds and record it if the outputs differ.

    Random inputs are generated once with ``random_array`` and handed to both
    builds, so each side starts from the same data. The function named ``op``
    is looked up in each module and called with all arguments; the last
    argument is treated as the output buffer and compared afterwards.

    Args:
        op: Name of the function to look up in both modules.
        types: Sequence of ``(shape, dtype)`` pairs, one per argument, in call
            order. Must be non-empty, because the last entry is the output.
        golden_lib: Loaded module for the reference build (``GOLDEN_DEVICE``).
        buggy_lib: Loaded module for the build under test (``BUGGY_DEVICE``).

    Returns:
        None. The verdict is printed, and on a mismatch ``op`` is appended to
        the module-level ``BUGGY_OPS`` list.
    """
    inputs = [random_array(shape, dtype) for shape, dtype in types]

    golden_dev = tvm.device(GOLDEN_DEVICE)
    buggy_dev = tvm.device(BUGGY_DEVICE)

    golden_inputs = _place_on_device(inputs, golden_dev)
    buggy_inputs = _place_on_device(inputs, buggy_dev)

    golden_func = golden_lib.get_function(op, query_imports=True)
    buggy_func = buggy_lib.get_function(op, query_imports=True)

    golden_func(*golden_inputs)
    buggy_func(*buggy_inputs)

    # Compare the output in the last buffer.
    golden_np = _as_numpy(golden_inputs[-1])
    buggy_np = _as_numpy(buggy_inputs[-1])

    # Operators with "matmul" in their name get looser tolerances.
    # NOTE(review): presumably to absorb accumulated float16 rounding — confirm.
    if "matmul" in op:
        atol, rtol = 5e-1, 1e-2
    else:
        atol, rtol = ATOL, RTOL

    # np.allclose(a, b) scales rtol by |b|, so the golden output must be the
    # second argument: the tolerance is then anchored to the reference values
    # instead of to the (possibly wrong) values produced by the tested build.
    if np.allclose(buggy_np, golden_np, atol=atol, rtol=rtol):
        print(f"✅ Operator {op} matches within tolerance (atol={atol}, rtol={rtol})")
    else:
        # Subtract in float64 so the reported difference cannot overflow in
        # float16 or wrap around for unsigned integer outputs.
        golden_f64 = np.asarray(golden_np, dtype=np.float64)
        buggy_f64 = np.asarray(buggy_np, dtype=np.float64)
        diff = np.max(np.abs(golden_f64 - buggy_f64))
        print(f"❌ Operator {op} mismatch! Max diff: {diff}")
        BUGGY_OPS.append(op)


def main():
    """Run every operator check and write the mismatching operators to disk.

    Loads the golden and buggy modules from ``GOLDEN_LIB`` and ``BUGGY_LIB``,
    calls ``test_operator`` once per entry of the case table below (in table
    order), prints the distinct operator names collected in ``BUGGY_OPS`` and
    writes them, one per line, to ``bottom-up.txt`` inside ``current_dir``.
    """
    golden_lib = tvm.runtime.load_module(GOLDEN_LIB)
    buggy_lib = tvm.runtime.load_module(BUGGY_LIB)

    # Use fixed dimensions for dynamic parameters.
    seq_len = 10
    vocab_size = 10000

    # One (operator name, [(shape, dtype), ...]) entry per check. The argument
    # list is in call order and its last entry is the buffer that
    # test_operator compares between the two builds.
    cases = [
        (
            "fused_dequantize_take1",
            [
                ((32064, 384), "uint32"),
                ((32064, 96), "float16"),
                ((seq_len,), "int32"),
                ((seq_len, 3072), "float16"),
            ],
        ),
        # -- prefill --
        (
            "rms_norm1",
            [
                ((1, seq_len, 3072), "float16"),
                ((3072,), "float16"),
                ((1, seq_len, 3072), "float16"),
            ],
        ),
        (
            "fused_dequantize1_NT_matmul5",
            [
                ((9216, 384), "uint32"),
                ((9216, 96), "float16"),
                ((1, seq_len, 3072), "float16"),
                ((1, seq_len, 9216), "float16"),
            ],
        ),
        (
            "reshape4",
            [
                ((1, seq_len, 9216), "float16"),
                ((1, seq_len, 96, 96), "float16"),
            ],
        ),
        (
            "reshape5",
            [
                ((1, seq_len, 96, 96), "float16"),
                ((seq_len, 96, 96), "float16"),
            ],
        ),
        (
            "reshape6",
            [
                ((seq_len, 32, 96), "float16"),
                ((1, seq_len, 32, 96), "float16"),
            ],
        ),
        (
            "reshape7",
            [
                ((1, seq_len, 32, 96), "float16"),
                ((1, seq_len, 3072), "float16"),
            ],
        ),
        (
            "fused_dequantize2_NT_matmul6",
            [
                ((3072, 384), "uint32"),
                ((3072, 96), "float16"),
                ((1, seq_len, 3072), "float16"),
                ((1, seq_len, 3072), "float16"),
            ],
        ),
        (
            "fuse_add_norm_prefill",
            [
                ((1, seq_len, 3072), "float16"),
                ((1, seq_len, 3072), "float16"),
                ((3072,), "float16"),
                ((1, seq_len, 3072), "float16"),
                ((1, seq_len, 3072), "float16"),
            ],
        ),
        (
            "fused_dequantize3_NT_matmul7",
            [
                ((16384, 384), "uint32"),
                ((16384, 96), "float16"),
                ((1, seq_len, 3072), "float16"),
                ((1, seq_len, 16384), "float16"),
            ],
        ),
        (
            "fused_split1_silu1_multiply1",
            [
                ((1, seq_len, 16384), "float16"),
                ((1, seq_len, 8192), "float16"),
            ],
        ),
        (
            "fused_dequantize4_NT_matmul8",
            [
                ((3072, 1024), "uint32"),
                ((3072, 256), "float16"),
                ((1, seq_len, 8192), "float16"),
                ((1, seq_len, 3072), "float16"),
            ],
        ),
        (
            "index",
            [
                ((1, seq_len, 3072), "float16"),
                ((1, 1, 3072), "float16"),
            ],
        ),
        # fused_dequantize5_fused_NT_matmul14_cast2 (prefill section)
        (
            "fused_dequantize5_fused_NT_matmul14_cast2",
            [
                ((vocab_size, 384), "uint32"),
                ((vocab_size, 96), "float16"),
                ((1, 1, 3072), "float16"),
                ((1, 1, vocab_size), "float32"),
            ],
        ),
        # -- decode --
        (
            "rms_norm2",
            [
                ((1, 1, 3072), "float16"),
                ((3072,), "float16"),
                ((1, 1, 3072), "float16"),
            ],
        ),
        (
            "fused_dequantize1_NT_matmul10",
            [
                ((9216, 384), "uint32"),
                ((9216, 96), "float16"),
                ((1, 1, 3072), "float16"),
                ((1, 1, 9216), "float16"),
            ],
        ),
        (
            "fused_reshape8_reshape9",
            [
                ((1, 1, 9216), "float16"),
                ((1, 96, 96), "float16"),
            ],
        ),
        (
            "fused_reshape10_reshape11",
            [
                ((1, 32, 96), "float16"),
                ((1, 1, 3072), "float16"),
            ],
        ),
        (
            "fused_dequantize2_NT_matmul11",
            [
                ((3072, 384), "uint32"),
                ((3072, 96), "float16"),
                ((1, 1, 3072), "float16"),
                ((1, 1, 3072), "float16"),
            ],
        ),
        (
            "fused_dequantize3_NT_matmul12",
            [
                ((16384, 384), "uint32"),
                ((16384, 96), "float16"),
                ((1, 1, 3072), "float16"),
                ((1, 1, 16384), "float16"),
            ],
        ),
        (
            "fused_split2_silu2_multiply2",
            [
                ((1, 1, 16384), "float16"),
                ((1, 1, 8192), "float16"),
            ],
        ),
        (
            "fused_dequantize4_NT_matmul13",
            [
                ((3072, 1024), "uint32"),
                ((3072, 256), "float16"),
                ((1, 1, 8192), "float16"),
                ((1, 1, 3072), "float16"),
            ],
        ),
        # fused_dequantize5_fused_NT_matmul14_cast2 (decode section)
        (
            "fused_dequantize5_fused_NT_matmul14_cast2",
            [
                ((vocab_size, 384), "uint32"),
                ((vocab_size, 96), "float16"),
                ((1, 1, 3072), "float16"),
                ((1, 1, vocab_size), "float32"),
            ],
        ),
        # NOTE: the attention-kernel checks below are disabled; why is not
        # recorded in this file. To re-enable them, define the dimensions
        # before this table and uncomment the entries.
        #
        # B = 2
        # max_num_pages = 3
        # nnz_pages = 5
        # total_len = 20
        # batch_size = 2
        # qo_len = 7
        # kv_len = 10
        #
        # -- attn_kernels --
        # (
        #     "batch_decode_paged_kv",
        #     [
        #         ((), "int32"),
        #         ((B, 32, 96), "float16"),
        #         ((max_num_pages, 2, 32, 16, 96), "float16"),
        #         ((B + 1,), "int32"),
        #         ((nnz_pages,), "int32"),
        #         ((B,), "int32"),
        #         ((B,), "int32"),
        #         ((B,), "int32"),
        #         ((B, 32, 96), "float16"),
        #         ((B, 32), "int32"),
        #         ((), "int32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #     ],
        # ),
        # (
        #     "batch_prefill_paged_kv",
        #     [
        #         ((), "int32"),
        #         ((total_len, 32, 96), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((max_num_pages, 2, 32, 16, 96), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((nnz_pages,), "int32"),
        #         ((batch_size,), "int32"),
        #         ((batch_size,), "int32"),
        #         ((total_len,), "int32"),
        #         ((total_len, 32, 96), "float16"),
        #         ((total_len, 32), "int32"),
        #         ((), "int32"),
        #         ((), "int32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #     ],
        # ),
        # (
        #     "batch_prefill_ragged_kv",
        #     [
        #         ((qo_len, 32, 96), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((kv_len, 32, 96), "float16"),
        #         ((kv_len, 32, 96), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((qo_len,), "int32"),
        #         ((batch_size,), "int32"),
        #         ((qo_len, 32, 96), "float16"),
        #         ((qo_len, 32), "int32"),
        #         ((), "int32"),
        #         ((), "int32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #     ],
        # ),
    ]

    for op, types in cases:
        test_operator(op, types, golden_lib, buggy_lib)

    # An operator that is checked more than once (the table lists
    # fused_dequantize5_fused_NT_matmul14_cast2 twice) can be recorded more
    # than once. dict.fromkeys drops the repeats and keeps first-seen order,
    # so the report is the same on every run; joining a set of strings gives
    # an order that can change between interpreter runs.
    distinct_ops = list(dict.fromkeys(BUGGY_OPS))

    print(f"Detected buggy operators: {distinct_ops}")
    with open(os.path.join(current_dir, "bottom-up.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(distinct_ops))


# Run the comparison only when this file is executed as a script, so that
# importing the module does not load the libraries or write the report.
if __name__ == "__main__":
    main()
