=============embed=============
fused_dequantize_take2
T.Buffer((151936, 256), 'uint32')
T.Buffer((151936, 64), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 2048), 'float16')
--------------------------------
=============prefill=============
rms_norm1
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(2048),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
fused_dequantize1_fused_NT_matmul7_add2
T.Buffer((T.int64(6144), T.int64(256)), 'uint32')
T.Buffer((T.int64(6144), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(6144),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(6144)), 'float16')
--------------------------------
reshape9
T.Buffer((T.int64(1), seq_len, T.int64(6144)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(48), T.int64(128)), 'float16')
--------------------------------
reshape10
T.Buffer((T.int64(1), seq_len, T.int64(48), T.int64(128)), 'float16')
T.Buffer((seq_len, T.int64(48), T.int64(128)), 'float16')
--------------------------------
reshape11
T.Buffer((seq_len, T.int64(16), T.int64(128)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(16), T.int64(128)), 'float16')
--------------------------------
reshape12
T.Buffer((T.int64(1), seq_len, T.int64(16), T.int64(128)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul8
T.Buffer((T.int64(2048), T.int64(256)), 'uint32')
T.Buffer((T.int64(2048), T.int64(64)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
fuse_add_norm_prefill
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((2048,), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
--------------------------------
reshape13
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((seq_len, T.int64(2048)), 'float16')
--------------------------------
fused_NT_matmul2_cast
T.Buffer((batch_size, T.int64(2048)), 'float16')
T.Buffer((T.int64(60), T.int64(2048)), 'float16')
T.Buffer((batch_size, T.int64(60)))
--------------------------------
fused_softmax_cast1
T.Buffer((batch_size, T.int64(60)))
T.Buffer((batch_size, T.int64(60)), 'float16')
--------------------------------
top4_softmax
T.Buffer((batch_size, 60), 'float16')
T.Buffer((batch_size, 4), 'float16')
T.Buffer((batch_size, 4), 'int32')
--------------------------------
fused_expert_mask_transpose
T.Buffer((batch_size, T.int64(4)), 'int32')
T.Buffer((T.int64(60), batch_size), 'int32')
--------------------------------
reshape5
T.Buffer((T.int64(60), batch_size), 'int32')
T.Buffer((batch_size * T.int64(60),), 'int32')
--------------------------------
gpu_2d_continuous_cumsum
T.Buffer((m, n), 'int32')
T.Buffer((m, n), 'int32')
--------------------------------
get_indices
T.Buffer((cumsum_len,), 'int32')
T.Buffer((batch_size, 4), 'int32')
T.Buffer((batch_size * 4,), 'int32')
T.Buffer((batch_size * 4,), 'int32')
--------------------------------
get_expert_instance_indptr
T.Buffer((batch_size * T.int64(60),), 'int32')
T.Buffer((61,), 'int32')
T.int64
--------------------------------
take
T.Buffer((batch_size, T.int64(2048)), 'float16')
T.Buffer((batch_size * T.int64(4),), 'int32')
T.Buffer((batch_size * T.int64(4), T.int64(2048)), 'float16')
--------------------------------
dequantize_group_gemm
T.Buffer((B, 2048), 'float16')
T.Buffer((60, 2816, 256), 'uint32')
T.Buffer((60, 2816, 64), 'float16')
T.Buffer((61,), 'int32')
T.Buffer((B, 2816), 'float16')
--------------------------------
fused_split_silu_multiply
T.Buffer((batch_size * T.int64(4), T.int64(2816)), 'float16')
T.Buffer((batch_size * T.int64(4), T.int64(1408)), 'float16')
T.int64
--------------------------------
dequantize_group_gemm1
T.Buffer((B, 1408), 'float16')
T.Buffer((60, 2048, 176), 'uint32')
T.Buffer((60, 2048, 44), 'float16')
T.Buffer((61,), 'int32')
T.Buffer((B, 2048), 'float16')
--------------------------------
scatter_output
T.Buffer((indices_len, 2048), 'float16')
T.Buffer((indices_len,), 'int32')
T.Buffer((indices_len, 2048), 'float16')
--------------------------------
reshape6
T.Buffer((batch_size, T.int64(4)), 'float16')
T.Buffer((batch_size, T.int64(4), T.int64(1)), 'float16')
--------------------------------
reshape7
T.Buffer((batch_size * T.int64(4), T.int64(2048)), 'float16')
T.Buffer((batch_size, T.int64(4), T.int64(2048)), 'float16')
--------------------------------
fused_multiply1_sum
T.Buffer((batch_size, T.int64(4), T.int64(2048)), 'float16')
T.Buffer((batch_size, T.int64(4), T.int64(1)), 'float16')
T.Buffer((batch_size, T.int64(2048)), 'float16')
--------------------------------
fused_dequantize3_NT_matmul3
T.Buffer((T.int64(11264), T.int64(256)), 'uint32')
T.Buffer((T.int64(11264), T.int64(64)), 'float16')
T.Buffer((batch_size, T.int64(2048)), 'float16')
T.Buffer((batch_size, T.int64(11264)), 'float16')
--------------------------------
fused_split1_silu1_multiply2
T.Buffer((batch_size, T.int64(11264)), 'float16')
T.Buffer((batch_size, T.int64(5632)), 'float16')
--------------------------------
fused_NT_matmul5_tir_sigmoid
T.Buffer((batch_size, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
T.Buffer((batch_size, T.int64(1)), 'float16')
--------------------------------
fused_dequantize4_fused_NT_matmul4_multiply3_add1
T.Buffer((T.int64(2048), T.int64(704)), 'uint32')
T.Buffer((T.int64(2048), T.int64(176)), 'float16')
T.Buffer((batch_size, T.int64(5632)), 'float16')
T.Buffer((batch_size, T.int64(1)), 'float16')
T.Buffer((batch_size, T.int64(2048)), 'float16')
T.Buffer((batch_size, T.int64(2048)), 'float16')
--------------------------------
reshape14
T.Buffer((seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
--------------------------------
index
T.Buffer((T.int64(1), seq_len, T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize_fused_NT_matmul16_cast6
T.Buffer((T.int64(151936), T.int64(256)), 'uint32')
T.Buffer((T.int64(151936), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), 'float32')
--------------------------------
=============decode=============
rms_norm2
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(2048),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize1_fused_NT_matmul10_add3
T.Buffer((T.int64(6144), T.int64(256)), 'uint32')
T.Buffer((T.int64(6144), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(6144),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(6144)), 'float16')
--------------------------------
fused_reshape15_reshape16
T.Buffer((T.int64(1), T.int64(1), T.int64(6144)), 'float16')
T.Buffer((T.int64(1), T.int64(48), T.int64(128)), 'float16')
--------------------------------
fused_reshape17_reshape18
T.Buffer((T.int64(1), T.int64(16), T.int64(128)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul11
T.Buffer((T.int64(2048), T.int64(256)), 'uint32')
T.Buffer((T.int64(2048), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fuse_add_norm_prefill
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((2048,), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
T.Buffer((1, seq_len, 2048), 'float16')
--------------------------------
fused_reshape19
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_NT_matmul12_cast4
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(60), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(60)), 'float32')
--------------------------------
fused_softmax1_cast5
T.Buffer((T.int64(1), T.int64(60)), 'float32')
T.Buffer((T.int64(1), T.int64(60)), 'float16')
--------------------------------
top4_softmax
T.Buffer((batch_size, 60), 'float16')
T.Buffer((batch_size, 4), 'float16')
T.Buffer((batch_size, 4), 'int32')
--------------------------------
fused_reshape20
T.Buffer((T.int64(1), T.int64(4)), 'float16')
T.Buffer((T.int64(1), T.int64(4), T.int64(1)), 'float16')
--------------------------------
moe_dequantize_gemv
T.Buffer((1, 2048), 'float16')
T.Buffer((60, 2816, 256), 'uint32')
T.Buffer((60, 2816, 64), 'float16')
T.Buffer((1, 4), 'int32')
T.Buffer((4, 2816), 'float16')
--------------------------------
fused_split2_silu2_multiply4
T.Buffer((T.int64(4), T.int64(2816)), 'float16')
T.Buffer((T.int64(4), T.int64(1408)), 'float16')
--------------------------------
moe_dequantize_gemv1
T.Buffer((4, 1408), 'float16')
T.Buffer((60, 2048, 176), 'uint32')
T.Buffer((60, 2048, 44), 'float16')
T.Buffer((1, 4), 'int32')
T.Buffer((4, 2048), 'float16')
--------------------------------
reshape21
T.Buffer((T.int64(4), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(4), T.int64(2048)), 'float16')
--------------------------------
fused_multiply5_sum1
T.Buffer((T.int64(1), T.int64(4), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(4), T.int64(1)), 'float16')
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize3_NT_matmul13
T.Buffer((T.int64(11264), T.int64(256)), 'uint32')
T.Buffer((T.int64(11264), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(11264)), 'float16')
--------------------------------
fused_split3_silu3_multiply6
T.Buffer((T.int64(1), T.int64(11264)), 'float16')
T.Buffer((T.int64(1), T.int64(5632)), 'float16')
--------------------------------
fused_NT_matmul15_tir_sigmoid1
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1)), 'float16')
--------------------------------
fused_dequantize4_fused_NT_matmul14_multiply7_add4
T.Buffer((T.int64(2048), T.int64(704)), 'uint32')
T.Buffer((T.int64(2048), T.int64(176)), 'float16')
T.Buffer((T.int64(1), T.int64(5632)), 'float16')
T.Buffer((T.int64(1), T.int64(1)), 'float16')
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
--------------------------------
reshape22
T.Buffer((T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
--------------------------------
fused_dequantize_fused_NT_matmul16_cast6
T.Buffer((T.int64(151936), T.int64(256)), 'uint32')
T.Buffer((T.int64(151936), T.int64(64)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), 'float32')
--------------------------------
=============attn_kernels=============
batch_decode_paged_kv
T.int32
T.Buffer((B, 16, 128), 'float16')
T.Buffer((max_num_pages, 2, 16, 16, 128), 'float16')
T.Buffer((B + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B, 16, 128), 'float16')
T.Buffer((B, 16))
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_paged_kv
T.int32
T.Buffer((total_len, 16, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((max_num_pages, 2, 16, 16, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((total_len,), 'int32')
T.Buffer((total_len, 16, 128), 'float16')
T.Buffer((total_len, 16))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_ragged_kv
T.Buffer((qo_len, 16, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((kv_len, 16, 128), 'float16')
T.Buffer((kv_len, 16, 128), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((qo_len,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((qo_len, 16, 128), 'float16')
T.Buffer((qo_len, 16))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
fused_rope
T.Buffer((seq_len, 48, 128), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 16, 128), 'float16')
T.Buffer((seq_len, 16, 128), 'float16')
T.Buffer((seq_len, 16, 128), 'float16')
T.int32
--------------------------------
