=============embed=============
fused_dequantize_take1
T.Buffer((32064, 384), 'uint32')
T.Buffer((32064, 96), 'float16')
T.Buffer((seq_len,), 'int32')
T.Buffer((seq_len, 3072), 'float16')
--------------------------------
=============prefill=============
rms_norm1
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
T.Buffer((T.int64(3072),), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
--------------------------------
fused_dequantize1_NT_matmul5
T.Buffer((T.int64(9216), T.int64(384)), 'uint32')
T.Buffer((T.int64(9216), T.int64(96)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(9216)), 'float16')
--------------------------------
reshape4
T.Buffer((T.int64(1), seq_len, T.int64(9216)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(96), T.int64(96)), 'float16')
--------------------------------
reshape5
T.Buffer((T.int64(1), seq_len, T.int64(96), T.int64(96)), 'float16')
T.Buffer((seq_len, T.int64(96), T.int64(96)), 'float16')
--------------------------------
reshape6
T.Buffer((seq_len, T.int64(32), T.int64(96)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(32), T.int64(96)), 'float16')
--------------------------------
reshape7
T.Buffer((T.int64(1), seq_len, T.int64(32), T.int64(96)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul6
T.Buffer((T.int64(3072), T.int64(384)), 'uint32')
T.Buffer((T.int64(3072), T.int64(96)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
--------------------------------
fuse_add_norm_prefill
T.Buffer((1, seq_len, 3072), 'float16')
T.Buffer((1, seq_len, 3072), 'float16')
T.Buffer((3072,), 'float16')
T.Buffer((1, seq_len, 3072), 'float16')
T.Buffer((1, seq_len, 3072), 'float16')
--------------------------------
fused_dequantize3_NT_matmul7
T.Buffer((T.int64(16384), T.int64(384)), 'uint32')
T.Buffer((T.int64(16384), T.int64(96)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(16384)), 'float16')
--------------------------------
fused_split1_silu1_multiply1
T.Buffer((T.int64(1), seq_len, T.int64(16384)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(8192)), 'float16')
--------------------------------
fused_dequantize4_NT_matmul8
T.Buffer((T.int64(3072), T.int64(1024)), 'uint32')
T.Buffer((T.int64(3072), T.int64(256)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(8192)), 'float16')
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
--------------------------------
index
T.Buffer((T.int64(1), seq_len, T.int64(3072)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
--------------------------------
fused_dequantize5_fused_NT_matmul14_cast2
T.Buffer((vocab_size, T.int64(384)), 'uint32')
T.Buffer((vocab_size, T.int64(96)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
T.Buffer((T.int64(1), T.int64(1), vocab_size))
--------------------------------
=============decode=============
rms_norm2
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
T.Buffer((T.int64(3072),), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
--------------------------------
fused_dequantize1_NT_matmul10
T.Buffer((T.int64(9216), T.int64(384)), 'uint32')
T.Buffer((T.int64(9216), T.int64(96)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(9216)), 'float16')
--------------------------------
fused_reshape8_reshape9
T.Buffer((T.int64(1), T.int64(1), T.int64(9216)), 'float16')
T.Buffer((T.int64(1), T.int64(96), T.int64(96)), 'float16')
--------------------------------
fused_reshape10_reshape11
T.Buffer((T.int64(1), T.int64(32), T.int64(96)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
--------------------------------
fused_dequantize2_NT_matmul11
T.Buffer((T.int64(3072), T.int64(384)), 'uint32')
T.Buffer((T.int64(3072), T.int64(96)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
--------------------------------
fuse_add_norm_prefill
T.Buffer((1, seq_len, 3072), 'float16')
T.Buffer((1, seq_len, 3072), 'float16')
T.Buffer((3072,), 'float16')
T.Buffer((1, seq_len, 3072), 'float16')
T.Buffer((1, seq_len, 3072), 'float16')
--------------------------------
fused_dequantize3_NT_matmul12
T.Buffer((T.int64(16384), T.int64(384)), 'uint32')
T.Buffer((T.int64(16384), T.int64(96)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(16384)), 'float16')
--------------------------------
fused_split2_silu2_multiply2
T.Buffer((T.int64(1), T.int64(1), T.int64(16384)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), 'float16')
--------------------------------
fused_dequantize4_NT_matmul13
T.Buffer((T.int64(3072), T.int64(1024)), 'uint32')
T.Buffer((T.int64(3072), T.int64(256)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
--------------------------------
fused_dequantize5_fused_NT_matmul14_cast2
T.Buffer((vocab_size, T.int64(384)), 'uint32')
T.Buffer((vocab_size, T.int64(96)), 'float16')
T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), 'float16')
T.Buffer((T.int64(1), T.int64(1), vocab_size))
--------------------------------
=============attn_kernels=============
batch_decode_paged_kv
T.int32
T.Buffer((B, 32, 96), 'float16')
T.Buffer((max_num_pages, 2, 32, 16, 96), 'float16')
T.Buffer((B + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B,), 'int32')
T.Buffer((B, 32, 96), 'float16')
T.Buffer((B, 32))
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_paged_kv
T.int32
T.Buffer((total_len, 32, 96), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((max_num_pages, 2, 32, 16, 96), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((nnz_pages,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((total_len,), 'int32')
T.Buffer((total_len, 32, 96), 'float16')
T.Buffer((total_len, 32))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
batch_prefill_ragged_kv
T.Buffer((qo_len, 32, 96), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((kv_len, 32, 96), 'float16')
T.Buffer((kv_len, 32, 96), 'float16')
T.Buffer((batch_size + 1,), 'int32')
T.Buffer((qo_len,), 'int32')
T.Buffer((batch_size,), 'int32')
T.Buffer((qo_len, 32, 96), 'float16')
T.Buffer((qo_len, 32))
T.int32
T.int32
T.float32
T.float32
T.float32
--------------------------------
