=============embed=============
fused_dequantize_take1
T.Buffer((vocab_size, 288), 'uint32')
T.Buffer((vocab_size, 72), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 2304), 'float16')
--------------------------------
=============prefill=============
multiply3
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
--------------------------------
rms_norm1
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
T.Buffer((T.int64(2304),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
--------------------------------
fused_dequantize1_NT_matmul5
T.Buffer((T.int64(4096), T.int64(288)), 'uint32')
T.Buffer((T.int64(4096), T.int64(72)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(4096)), 'float16')
--------------------------------
reshape4
T.Buffer((T.int64(1), seq_len, T.int64(4096)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(16), T.int64(256)), 'float16')
--------------------------------
reshape5
T.Buffer((T.int64(1), seq_len, T.int64(16), T.int64(256)), 'float16')
T.Buffer((seq_len, T.int64(16), T.int64(256)), 'float16')
--------------------------------
reshape6
T.Buffer((seq_len, T.int64(8), T.int64(256)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(8), T.int64(256)), 'float16')
--------------------------------
reshape7
T.Buffer((T.int64(1), seq_len, T.int64(8), T.int64(256)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul6
T.Buffer((T.int64(2304), T.int64(256)), 'uint32')
T.Buffer((T.int64(2304), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
--------------------------------
fused_rms_norm1_add1
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
T.Buffer((T.int64(2304),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
--------------------------------
fused_dequantize3_NT_matmul7
T.Buffer((T.int64(18432), T.int64(288)), 'uint32')
T.Buffer((T.int64(18432), T.int64(72)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(18432)), 'float16')
--------------------------------
fused_split1_gelu_tanh1_multiply4
T.Buffer((T.int64(1), seq_len, T.int64(18432)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(9216)), 'float16')
--------------------------------
fused_dequantize4_NT_matmul8
T.Buffer((T.int64(2304), T.int64(1152)), 'uint32')
T.Buffer((T.int64(2304), T.int64(288)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(9216)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
--------------------------------
index
T.Buffer((T.int64(1), seq_len, T.int64(2304)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
--------------------------------
fused_dequantize_fused_NT_matmul14_divide2_tir_tanh2_multiply8
T.Buffer((vocab_size, T.int64(288)), 'uint32')
T.Buffer((vocab_size, T.int64(72)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(1), T.int64(1), vocab_size))
--------------------------------
=============decode=============
multiply6
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
--------------------------------
rms_norm2
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(2304),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
--------------------------------
fused_dequantize1_NT_matmul10
T.Buffer((T.int64(4096), T.int64(288)), 'uint32')
T.Buffer((T.int64(4096), T.int64(72)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 'float16')
--------------------------------
fused_reshape8_reshape9
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), 'float16')
T.Buffer((T.int64(1), T.int64(16), T.int64(256)), 'float16')
--------------------------------
fused_reshape10_reshape11
T.Buffer((T.int64(1), T.int64(8), T.int64(256)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul11
T.Buffer((T.int64(2304), T.int64(256)), 'uint32')
T.Buffer((T.int64(2304), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
--------------------------------
fused_rms_norm2_add2
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(2304),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
--------------------------------
fused_dequantize3_NT_matmul12
T.Buffer((T.int64(18432), T.int64(288)), 'uint32')
T.Buffer((T.int64(18432), T.int64(72)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(18432)), 'float16')
--------------------------------
fused_split2_gelu_tanh2_multiply7
T.Buffer((T.int64(1), T.int64(1), T.int64(18432)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(9216)), 'float16')
--------------------------------
fused_dequantize4_NT_matmul13
T.Buffer((T.int64(2304), T.int64(1152)), 'uint32')
T.Buffer((T.int64(2304), T.int64(288)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(9216)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
--------------------------------
fused_dequantize_fused_NT_matmul14_divide2_tir_tanh2_multiply8
T.Buffer((vocab_size, T.int64(288)), 'uint32')
T.Buffer((vocab_size, T.int64(72)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2304)), 'float16')
T.Buffer((T.int64(1), T.int64(1), vocab_size))
--------------------------------
=============attn_kernels=============
batch_decode_paged_kv
T.int32
T.Buffer((B, 8, 256), 'float16')
T.Buffer((max_num_pages, 2, 4, 16, 256), 'float16')
T.Buffer((B + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B, 8, 256), 'float16')
T.Buffer((B, 8))
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_paged_kv
T.int32
T.Buffer((total_len, 8, 256), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((max_num_pages, 2, 4, 16, 256), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((total_len,), 'int32')
T.Buffer((total_len, 8, 256), 'float16')
T.Buffer((total_len, 8))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_ragged_kv
T.Buffer((qo_len, 8, 256), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((kv_len, 4, 256), 'float16')
T.Buffer((kv_len, 4, 256), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((qo_len,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((qo_len, 8, 256), 'float16')
T.Buffer((qo_len, 8))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
fused_rope
T.Buffer((seq_len, 16, 256), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 8, 256), 'float16')
T.Buffer((seq_len, 4, 256), 'float16')
T.Buffer((seq_len, 4, 256), 'float16')
T.int32
--------------------------------
