=============embed=============
fused_dequantize_take1
T.Buffer((151936, 112), 'uint32')
T.Buffer((151936, 28), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 896), 'float16')
--------------------------------
=============prefill=============
rms_norm1
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
T.Buffer((T.int64(896),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
--------------------------------
fused_dequantize1_fused_NT_matmul5_add2
T.Buffer((T.int64(1152), T.int64(112)), 'uint32')
T.Buffer((T.int64(1152), T.int64(28)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
T.Buffer((T.int64(1152),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(1152)), 'float16')
--------------------------------
reshape4
T.Buffer((T.int64(1), seq_len, T.int64(1152)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(18), T.int64(64)), 'float16')
--------------------------------
reshape5
T.Buffer((T.int64(1), seq_len, T.int64(18), T.int64(64)), 'float16')
T.Buffer((seq_len, T.int64(18), T.int64(64)), 'float16')
--------------------------------
reshape6
T.Buffer((seq_len, T.int64(14), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(14), T.int64(64)), 'float16')
--------------------------------
reshape7
T.Buffer((T.int64(1), seq_len, T.int64(14), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
--------------------------------
fused_dequantize2_fused_NT_matmul6_add3
T.Buffer((T.int64(896), T.int64(112)), 'uint32')
T.Buffer((T.int64(896), T.int64(28)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
--------------------------------
fused_dequantize3_NT_matmul7
T.Buffer((T.int64(9728), T.int64(112)), 'uint32')
T.Buffer((T.int64(9728), T.int64(28)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(9728)), 'float16')
--------------------------------
fused_split1_silu1_multiply1
T.Buffer((T.int64(1), seq_len, T.int64(9728)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(4864)), 'float16')
--------------------------------
fused_dequantize4_fused_NT_matmul8_add3
T.Buffer((T.int64(896), T.int64(608)), 'uint32')
T.Buffer((T.int64(896), T.int64(152)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(4864)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
--------------------------------
index
T.Buffer((T.int64(1), seq_len, T.int64(896)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
--------------------------------
fused_dequantize_NT_matmul14
T.Buffer((T.int64(151936), T.int64(112)), 'uint32')
T.Buffer((T.int64(151936), T.int64(28)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), 'float32')
--------------------------------
=============decode=============
rms_norm2
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(896),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
--------------------------------
fused_dequantize1_fused_NT_matmul10_add4
T.Buffer((T.int64(1152), T.int64(112)), 'uint32')
T.Buffer((T.int64(1152), T.int64(28)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(1152),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(1152)), 'float16')
--------------------------------
fused_reshape8_reshape9
T.Buffer((T.int64(1), T.int64(1), T.int64(1152)), 'float16')
T.Buffer((T.int64(1), T.int64(18), T.int64(64)), 'float16')
--------------------------------
fused_reshape10_reshape11
T.Buffer((T.int64(1), T.int64(14), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
--------------------------------
fused_dequantize2_fused_NT_matmul11_add5
T.Buffer((T.int64(896), T.int64(112)), 'uint32')
T.Buffer((T.int64(896), T.int64(28)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
--------------------------------
fused_dequantize3_NT_matmul12
T.Buffer((T.int64(9728), T.int64(112)), 'uint32')
T.Buffer((T.int64(9728), T.int64(28)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(9728)), 'float16')
--------------------------------
fused_split2_silu2_multiply2
T.Buffer((T.int64(1), T.int64(1), T.int64(9728)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(4864)), 'float16')
--------------------------------
fused_dequantize4_fused_NT_matmul13_add5
T.Buffer((T.int64(896), T.int64(608)), 'uint32')
T.Buffer((T.int64(896), T.int64(152)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(4864)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
--------------------------------
fused_dequantize_NT_matmul14
T.Buffer((T.int64(151936), T.int64(112)), 'uint32')
T.Buffer((T.int64(151936), T.int64(28)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(896)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), 'float32')
--------------------------------
=============attn_kernels=============
batch_decode_paged_kv
T.int32
T.Buffer((B, 14, 64), 'float16')
T.Buffer((max_num_pages, 2, 2, 16, 64), 'float16')
T.Buffer((B + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B, 14, 64), 'float16')
T.Buffer((B, 14))
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_paged_kv
T.int32
T.Buffer((total_len, 14, 64), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((max_num_pages, 2, 2, 16, 64), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((total_len,), 'int32')
T.Buffer((total_len, 14, 64), 'float16')
T.Buffer((total_len, 14))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_ragged_kv
T.Buffer((qo_len, 14, 64), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((kv_len, 2, 64), 'float16')
T.Buffer((kv_len, 2, 64), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((qo_len,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((qo_len, 14, 64), 'float16')
T.Buffer((qo_len, 14))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
fused_rope
T.Buffer((seq_len, 18, 64), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 14, 64), 'float16')
T.Buffer((seq_len, 2, 64), 'float16')
T.Buffer((seq_len, 2, 64), 'float16')
T.int32
--------------------------------
