=============embed=============
fused_dequantize_take1
T.Buffer((92544, 256), 'uint32')
T.Buffer((92544, 64), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 2048), 'float16')
--------------------------------
=============prefill=============
rms_norm1
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(2048),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
fused_dequantize1_NT_matmul5
T.Buffer((T.int64(4096), T.int64(256)), 'uint32')
T.Buffer((T.int64(4096), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(4096)), 'float16')
--------------------------------
reshape4
T.Buffer((T.int64(1), seq_len, T.int64(4096)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(32), T.int64(128)), 'float16')
--------------------------------
reshape5
T.Buffer((T.int64(1), seq_len, T.int64(32), T.int64(128)), 'float16')
T.Buffer((seq_len, T.int64(32), T.int64(128)), 'float16')
--------------------------------
reshape6
T.Buffer((seq_len, T.int64(16), T.int64(128)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(16), T.int64(128)), 'float16')
--------------------------------
reshape7
T.Buffer((T.int64(1), seq_len, T.int64(16), T.int64(128)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul6
T.Buffer((T.int64(2048), T.int64(256)), 'uint32')
T.Buffer((T.int64(2048), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
fuse_add_norm_prefill
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((2048,), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
--------------------------------
fused_dequantize3_NT_matmul7
T.Buffer((T.int64(16384), T.int64(256)), 'uint32')
T.Buffer((T.int64(16384), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(16384)), 'float16')
--------------------------------
fused_split1_silu1_multiply1
T.Buffer((T.int64(1), seq_len, T.int64(16384)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(8192)), 'float16')
--------------------------------
fused_dequantize4_NT_matmul8
T.Buffer((T.int64(2048), T.int64(1024)), 'uint32')
T.Buffer((T.int64(2048), T.int64(256)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(8192)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
index
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize_fused_NT_matmul14_cast2
T.Buffer((T.int64(92544), T.int64(256)), 'uint32')
T.Buffer((T.int64(92544), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(92544)), 'float32')
--------------------------------
=============decode=============
rms_norm2
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(2048),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize1_NT_matmul10
T.Buffer((T.int64(4096), T.int64(256)), 'uint32')
T.Buffer((T.int64(4096), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 'float16')
--------------------------------
fused_reshape8_reshape9
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 'float16')
T.Buffer((T.int64(1), T.int64(32), T.int64(128)), 'float16')
--------------------------------
fused_reshape10_reshape11
T.Buffer((T.int64(1), T.int64(16), T.int64(128)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul11
T.Buffer((T.int64(2048), T.int64(256)), 'uint32')
T.Buffer((T.int64(2048), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fuse_add_norm_prefill
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((2048,), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
--------------------------------
fused_dequantize3_NT_matmul12
T.Buffer((T.int64(16384), T.int64(256)), 'uint32')
T.Buffer((T.int64(16384), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(16384)), 'float16')
--------------------------------
fused_split2_silu2_multiply2
T.Buffer((T.int64(1), T.int64(1), T.int64(16384)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), 'float16')
--------------------------------
fused_dequantize4_NT_matmul13
T.Buffer((T.int64(2048), T.int64(1024)), 'uint32')
T.Buffer((T.int64(2048), T.int64(256)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize_fused_NT_matmul14_cast2
T.Buffer((T.int64(92544), T.int64(256)), 'uint32')
T.Buffer((T.int64(92544), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(92544)), 'float32')
--------------------------------
=============attn_kernels=============
batch_decode_paged_kv
T.int32
T.Buffer((B, 16, 128), 'float16')
T.Buffer((max_num_pages, 2, 8, 16, 128), 'float16')
T.Buffer((B + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B, 16, 128), 'float16')
T.Buffer((B, 16))
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_paged_kv
T.int32
T.Buffer((total_len, 16, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((max_num_pages, 2, 8, 16, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((total_len,), 'int32')
T.Buffer((total_len, 16, 128), 'float16')
T.Buffer((total_len, 16))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_ragged_kv
T.Buffer((qo_len, 16, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((kv_len, 8, 128), 'float16')
T.Buffer((kv_len, 8, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((qo_len,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((qo_len, 16, 128), 'float16')
T.Buffer((qo_len, 16))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
fused_rope
T.Buffer((seq_len, 32, 128), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 16, 128), 'float16')
T.Buffer((seq_len, 8, 128), 'float16')
T.Buffer((seq_len, 8, 128), 'float16')
T.int32
--------------------------------
