import tvm
import numpy as np
import os

# Directory holding this script; the compiled libraries and the
# bottom-up.txt report are resolved relative to it.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Reference ("golden") target and the target under test, each paired with
# the compiled library that is loaded for it.
GOLDEN_DEVICE, GOLDEN_LIB = "cuda", os.path.join(current_dir, "cuda.so")
BUGGY_DEVICE, BUGGY_LIB = "vulkan", os.path.join(current_dir, "vulkan.so")

# Default comparison tolerances; operators with "matmul" in their name are
# compared with looser values.
ATOL, RTOL = 1e-1, 1e-3

# Names of operators whose outputs mismatched, appended as the tests run.
BUGGY_OPS = []

# Concrete values substituted for the symbolic dimensions in the operator
# signatures. In this file, B/batch_size and the paged/ragged KV sizes are
# only referenced by the commented-out attention-kernel cases.
SEQ_LEN = 128
B = batch_size = 4
total_len = 20
max_num_pages = 3
nnz_pages = 5
qo_len = 10
kv_len = 8


def random_array(shape, dtype):
    """Return random data of the given shape and dtype for an operator argument.

    Integer dtypes (signed or unsigned) get uniform integers in ``[0, 100)``;
    every other dtype gets standard-normal samples cast to ``dtype``. An
    empty ``shape`` (``()``) yields a plain Python scalar instead of an array.

    Args:
        shape: Tuple of dimension sizes; ``()`` requests a scalar.
        dtype: Anything ``np.dtype`` accepts, e.g. ``"int32"``, ``"float16"``,
            ``np.uint32`` or ``np.dtype("float32")``.

    Returns:
        An ``np.ndarray`` with the requested shape and dtype, or a Python
        scalar when ``shape`` is ``()``.
    """
    np_dtype = np.dtype(dtype)

    # Classify by dtype kind ("i" = signed, "u" = unsigned integer) rather than
    # by substring, so dtype objects and scalar types work as well as strings.
    if np_dtype.kind in "iu":
        if len(shape) == 0:
            return np.array(np.random.randint(0, 100), dtype=np_dtype).item()
        return np.random.randint(0, 100, size=shape).astype(np_dtype)

    if len(shape) == 0:
        return np.array(np.random.randn(), dtype=np_dtype).item()

    return np.random.randn(*shape).astype(np_dtype)


def test_operator(op, types, golden_lib, buggy_lib, output_indices=None):
    """Run one operator on both backends with identical inputs and compare.

    Random inputs are generated from ``types`` with ``random_array``. Array
    arguments are copied to each device, while scalar arguments are passed
    through unchanged. The function named ``op`` is looked up in each library
    (including imported modules) and called with those arguments. The
    selected output buffer(s) are then copied back and compared with
    ``np.allclose``. Operators whose name contains ``"matmul"`` use a looser
    tolerance. A mismatching operator is printed and appended to the
    module-level ``BUGGY_OPS`` list.

    Args:
        op: Name of the function to call in both libraries.
        types: List of ``(shape, dtype)`` pairs, one per argument in call
            order; ``shape == ()`` means a scalar argument.
        golden_lib: Module run on ``GOLDEN_DEVICE`` (reference results).
        buggy_lib: Module run on ``BUGGY_DEVICE`` (results under test).
        output_indices: Argument position (int) or positions (iterable of
            ints) holding the outputs to compare. By default the last *array*
            argument is compared, so trailing scalar arguments are skipped.
    """
    inputs = [random_array(shape, dtype) for shape, dtype in types]

    def _to_device(dev):
        # Arrays are copied to ``dev``; scalar arguments are passed by value.
        return [
            tvm.nd.array(arr, device=dev) if isinstance(arr, np.ndarray) else arr
            for arr in inputs
        ]

    def _to_host(value):
        # Copy a device array back to NumPy; scalars are returned unchanged.
        return value.numpy() if isinstance(value, tvm.nd.NDArray) else value

    golden_inputs = _to_device(tvm.device(GOLDEN_DEVICE))
    buggy_inputs = _to_device(tvm.device(BUGGY_DEVICE))

    golden_func = golden_lib.get_function(op, query_imports=True)
    buggy_func = buggy_lib.get_function(op, query_imports=True)

    golden_func(*golden_inputs)
    buggy_func(*buggy_inputs)

    if output_indices is None:
        # Compare the last array argument. A scalar argument is the very same
        # Python object on both sides, so comparing a trailing scalar would
        # always "match" and hide any real mismatch in the output buffers.
        array_positions = [
            i for i, arr in enumerate(inputs) if isinstance(arr, np.ndarray)
        ]
        output_indices = [array_positions[-1] if array_positions else -1]
    elif isinstance(output_indices, int):
        output_indices = [output_indices]

    if "matmul" in op:
        atol, rtol = 5e-1, 1e-2
    else:
        atol, rtol = ATOL, RTOL

    mismatch_diffs = []
    for idx in output_indices:
        # Compare in float64 so float16 differences cannot overflow to inf and
        # unsigned integer outputs cannot wrap around when subtracted.
        golden_np = np.asarray(_to_host(golden_inputs[idx]), dtype=np.float64)
        buggy_np = np.asarray(_to_host(buggy_inputs[idx]), dtype=np.float64)
        if not np.allclose(golden_np, buggy_np, atol=atol, rtol=rtol):
            mismatch_diffs.append(np.max(np.abs(golden_np - buggy_np)))

    if not mismatch_diffs:
        print(f"✅ Operator {op} matches within tolerance (atol={atol}, rtol={rtol})")
    else:
        # np.max (not builtin max) so a NaN difference is always reported.
        diff = np.max(mismatch_diffs)
        print(f"❌ Operator {op} mismatch! Max diff: {diff}")
        BUGGY_OPS.append(op)


def main():
    """Compare every listed operator between the golden and buggy backends.

    Loads ``GOLDEN_LIB`` and ``BUGGY_LIB`` and runs each case below through
    ``test_operator``, in the order listed. It then prints the operators that
    mismatched and writes their names, one per line and each at most once,
    to ``bottom-up.txt`` next to this script.
    """
    golden_lib = tvm.runtime.load_module(GOLDEN_LIB)
    buggy_lib = tvm.runtime.load_module(BUGGY_LIB)

    # Each case is (operator name, [(shape, dtype), ...]) with one entry per
    # kernel argument, in call order. A shape of () denotes a scalar argument.
    cases = [
        # =================== EMBED ===================
        (
            "fused_dequantize_take1",
            [
                ((151936, 112), "uint32"),
                ((151936, 28), "float16"),
                ((SEQ_LEN,), "int32"),
                ((SEQ_LEN, 896), "float16"),
            ],
        ),
        # =================== PREFILL ===================
        (
            "rms_norm1",
            [
                ((1, SEQ_LEN, 896), "float16"),
                ((896,), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
            ],
        ),
        (
            "fused_dequantize1_fused_NT_matmul5_add2",
            [
                ((1152, 112), "uint32"),
                ((1152, 28), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
                ((1152,), "float16"),
                ((1, SEQ_LEN, 1152), "float16"),
            ],
        ),
        (
            "reshape4",
            [
                ((1, SEQ_LEN, 1152), "float16"),
                ((1, SEQ_LEN, 18, 64), "float16"),
            ],
        ),
        (
            "reshape5",
            [
                ((1, SEQ_LEN, 18, 64), "float16"),
                ((SEQ_LEN, 18, 64), "float16"),
            ],
        ),
        (
            "reshape6",
            [
                ((SEQ_LEN, 14, 64), "float16"),
                ((1, SEQ_LEN, 14, 64), "float16"),
            ],
        ),
        (
            "reshape7",
            [
                ((1, SEQ_LEN, 14, 64), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
            ],
        ),
        (
            "fused_dequantize2_fused_NT_matmul6_add3",
            [
                ((896, 112), "uint32"),
                ((896, 28), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
            ],
        ),
        (
            "fused_dequantize3_NT_matmul7",
            [
                ((9728, 112), "uint32"),
                ((9728, 28), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
                ((1, SEQ_LEN, 9728), "float16"),
            ],
        ),
        (
            "fused_split1_silu1_multiply1",
            [
                ((1, SEQ_LEN, 9728), "float16"),
                ((1, SEQ_LEN, 4864), "float16"),
            ],
        ),
        (
            "fused_dequantize4_fused_NT_matmul8_add3",
            [
                ((896, 608), "uint32"),
                ((896, 152), "float16"),
                ((1, SEQ_LEN, 4864), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
                ((1, SEQ_LEN, 896), "float16"),
            ],
        ),
        (
            "index",
            [
                ((1, SEQ_LEN, 896), "float16"),
                ((1, 1, 896), "float16"),
            ],
        ),
        # fused_dequantize_NT_matmul14 (from embed)
        (
            "fused_dequantize_NT_matmul14",
            [
                ((151936, 112), "uint32"),
                ((151936, 28), "float16"),
                ((1, 1, 896), "float16"),
                ((1, 1, 151936), "float32"),
            ],
        ),
        # =================== DECODE ===================
        (
            "rms_norm2",
            [
                ((1, 1, 896), "float16"),
                ((896,), "float16"),
                ((1, 1, 896), "float16"),
            ],
        ),
        (
            "fused_dequantize1_fused_NT_matmul10_add4",
            [
                ((1152, 112), "uint32"),
                ((1152, 28), "float16"),
                ((1, 1, 896), "float16"),
                ((1152,), "float16"),
                ((1, 1, 1152), "float16"),
            ],
        ),
        (
            "fused_reshape8_reshape9",
            [
                ((1, 1, 1152), "float16"),
                ((1, 18, 64), "float16"),
            ],
        ),
        (
            "fused_reshape10_reshape11",
            [
                ((1, 14, 64), "float16"),
                ((1, 1, 896), "float16"),
            ],
        ),
        (
            "fused_dequantize2_fused_NT_matmul11_add5",
            [
                ((896, 112), "uint32"),
                ((896, 28), "float16"),
                ((1, 1, 896), "float16"),
                ((1, 1, 896), "float16"),
                ((1, 1, 896), "float16"),
            ],
        ),
        (
            "fused_dequantize3_NT_matmul12",
            [
                ((9728, 112), "uint32"),
                ((9728, 28), "float16"),
                ((1, 1, 896), "float16"),
                ((1, 1, 9728), "float16"),
            ],
        ),
        (
            "fused_split2_silu2_multiply2",
            [
                ((1, 1, 9728), "float16"),
                ((1, 1, 4864), "float16"),
            ],
        ),
        (
            "fused_dequantize4_fused_NT_matmul13_add5",
            [
                ((896, 608), "uint32"),
                ((896, 152), "float16"),
                ((1, 1, 4864), "float16"),
                ((1, 1, 896), "float16"),
                ((1, 1, 896), "float16"),
            ],
        ),
        # fused_dequantize_NT_matmul14 (from decode): same signature as the
        # embed-stage case above; it runs again here with fresh random inputs.
        (
            "fused_dequantize_NT_matmul14",
            [
                ((151936, 112), "uint32"),
                ((151936, 28), "float16"),
                ((1, 1, 896), "float16"),
                ((1, 1, 151936), "float32"),
            ],
        ),
        # =================== ATTN_KERNELS ===================
        # Disabled; kept for reference.
        # (
        #     "batch_decode_paged_kv",
        #     [
        #         ((), "int32"),
        #         ((B, 14, 64), "float16"),
        #         ((max_num_pages, 2, 2, 16, 64), "float16"),
        #         ((B + 1,), "int32"),
        #         ((nnz_pages,), "int32"),
        #         ((B,), "int32"),
        #         ((B,), "int32"),
        #         ((B,), "int32"),
        #         ((B, 14, 64), "float16"),
        #         ((B, 14), "int32"),  # assumed type for this buffer
        #         ((), "int32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #     ],
        # ),
        # (
        #     "batch_prefill_paged_kv",
        #     [
        #         ((), "int32"),
        #         ((total_len, 14, 64), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((max_num_pages, 2, 2, 16, 64), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((nnz_pages,), "int32"),
        #         ((batch_size,), "int32"),
        #         ((batch_size,), "int32"),
        #         ((total_len,), "int32"),
        #         ((total_len, 14, 64), "float16"),
        #         ((total_len, 14), "int32"),
        #         ((), "int32"),
        #         ((), "int32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #     ],
        # ),
        # (
        #     "batch_prefill_ragged_kv",
        #     [
        #         ((qo_len, 14, 64), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((kv_len, 2, 64), "float16"),
        #         ((kv_len, 2, 64), "float16"),
        #         ((batch_size + 1,), "int32"),
        #         ((qo_len,), "int32"),
        #         ((batch_size,), "int32"),
        #         ((qo_len, 14, 64), "float16"),
        #         ((qo_len, 14), "int32"),
        #         ((), "int32"),
        #         ((), "int32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #         ((), "float32"),
        #     ],
        # ),
        # NOTE(review): args 2-4 look like the q/k/v outputs and the trailing
        # scalar like a flag -- confirm that test_operator compares an output
        # buffer for this case rather than the scalar.
        (
            "fused_rope",
            [
                ((SEQ_LEN, 18, 64), "float16"),
                ((SEQ_LEN,), "int32"),
                ((SEQ_LEN, 14, 64), "float16"),
                ((SEQ_LEN, 2, 64), "float16"),
                ((SEQ_LEN, 2, 64), "float16"),
                ((), "int32"),
            ],
        ),
    ]

    for op, types in cases:
        test_operator(op, types, golden_lib, buggy_lib)

    # fused_dequantize_NT_matmul14 runs in two stages, so the same name can be
    # recorded twice; report each operator once, in first-seen order.
    detected = list(dict.fromkeys(BUGGY_OPS))
    print(f"Detected buggy operators: {detected}")
    report_path = os.path.join(current_dir, "bottom-up.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("\n".join(detected))


if __name__ == "__main__":
    main()
