# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        # Mask out logits in place according to a packed bitmask.
        #
        # For every pair (vs, vv) with vs < num_seq and vv < vocab_size, the row
        # r = seq_ids[vs] is selected and bit (vv % 32) of the int32 word
        # bitmask[r, vv // 32] is tested. If the bit is 1 the logit is kept;
        # otherwise logits[r, vv] is overwritten with
        # -340282346638528859811704183484516925440.0 (the most negative finite
        # float32), which effectively excludes that vocab entry.
        # Rows of logits that do not appear in seq_ids are left untouched.
        #
        # Arguments (shapes taken from the match_buffer calls below):
        #   var_logits  -- (batch_size, vocab_size), default (float32) dtype; modified in place.
        #   var_seq_ids -- (num_seq,) int32; rows of logits / bitmask to process.
        #   var_bitmask -- (batch_size, (vocab_size + 31) // 32) int32; 32 vocab entries per word.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        # One GPU thread per (sequence, vocab) pair. The flat index
        # fused_s_v_0 * 1024 + fused_s_v_1 is split into vs (position in seq_ids)
        # and vv (vocab column). The T.where guard idles the surplus threads of
        # the last 1024-thread block.
        # NOTE(review): the flat index and num_seq * vocab_size are int32, so
        # num_seq * vocab_size must stay below 2**31 -- confirm this holds for callers.
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    # Keep the logit if bit (vv % 32) of the mask word is set, otherwise replace it with -FLT_MAX.
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0))

    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        # Add a per-entry bias to selected logits, in place.
        #
        # For each entry vp < num_token:
        #     logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp]
        #
        # Arguments (shapes taken from the match_buffer calls below):
        #   var_logits     -- (batch_size, vocab_size), default (float32) dtype; modified in place.
        #   var_pos2seq_id -- (num_token,) int32; logits row of each entry.
        #   var_token_ids  -- (num_token,) int32; logits column of each entry.
        #   var_logit_bias -- (num_token,) bias values, same default dtype as logits.
        # NOTE(review): each entry is handled by its own thread with a plain
        # (non-atomic) read-modify-write, so this presumably assumes no two entries
        # share the same (pos2seq_id, token_ids) pair -- confirm with the caller.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        # One thread per entry. T.where masks the surplus threads of the last 1024-thread block.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        # Apply per-sequence token penalties to selected logits, in place.
        #
        # For each entry vp < num_token, with s = pos2seq_id[vp],
        # r = seq_ids[s] and t = token_ids[vp]:
        #     logits[r, t] -= penalties[s, 0] + float(token_cnt[vp]) * penalties[s, 1]
        #     logits[r, t]  = logits[r, t] * penalties[s, 2]   if logits[r, t] < 0
        #                     logits[r, t] / penalties[s, 2]   otherwise
        # The indexing is asymmetric: penalties is indexed by s (a slot in seq_ids),
        # while logits is indexed by the row r = seq_ids[s].
        #
        # Arguments (shapes taken from the match_buffer calls below):
        #   var_logits     -- (batch_size, vocab_size), default (float32) dtype; modified in place.
        #   var_seq_ids    -- (num_seq,) int32; maps a sequence slot to a logits row.
        #   var_pos2seq_id -- (num_token,) int32; sequence slot of each entry.
        #   var_token_ids  -- (num_token,) int32; logits column of each entry.
        #   var_token_cnt  -- (num_token,) int32; count that multiplies penalties[s, 1].
        #   var_penalties  -- (num_seq, 3). Judging from the formula above, the columns
        #                     are presumably presence, frequency and repetition
        #                     penalty -- confirm against the caller.
        # NOTE(review): there is no guard against penalties[s, 2] == 0 (division by
        # zero). Duplicate (r, t) pairs across entries would also race, because the
        # update is a plain read-modify-write. Presumably callers rule both out.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        # One thread per entry. T.where masks the surplus threads of the last 1024-thread block.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    # Step 1: subtract the constant term plus the count-proportional term.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    # Step 2: sign-dependent scaling -- multiply negative logits and divide
                    # non-negative ones by penalties[s, 2].
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    @T.prim_func(private=True)
    def argsort_thrust(var_probs: T.handle, var_lv: T.handle, var_topk_gpu_v1: T.handle):
        # Row-wise sort of a 2-D tensor via the external packed function
        # "tvm.contrib.thrust.sort". Only the sorted indices are returned to the caller.
        #
        # Arguments (shapes taken from the match_buffer calls below):
        #   var_probs       -- (batch_size, vocab_size) input values, default (float32) dtype.
        #   var_lv          -- uint8 scratch workspace of
        #                      8 * (batch_size * vocab_size * 4) + 8388608
        #                      + batch_size * vocab_size * 12 bytes, passed through to thrust.
        #   var_topk_gpu_v1 -- (batch_size, vocab_size) int32; receives the sort indices.
        # The sorted values are written to the function-local value_buf and discarded.
        # NOTE(review): the literal 0 passed after the three arrays is presumably the
        # is_ascend flag, i.e. descending order -- verify against the packed
        # function's registration.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        data_buf = T.match_buffer(var_probs, (batch_size, vocab_size), align=8)
        workspace_buf = T.match_buffer(var_lv, (T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12),), "uint8", align=8)
        indices_buf = T.match_buffer(var_topk_gpu_v1, (batch_size, vocab_size), "int32", align=8)
        # with T.block("root"):
        # Scratch output for the sorted values; only the indices are kept.
        value_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        with T.block("topk_gpu"):
            # The block declares no reads or writes; every buffer access happens
            # inside the opaque packed call below.
            T.reads()
            T.writes()
            # Arguments: input, sorted-values output, indices output, flag (0), workspace.
            # Each buffer is wrapped with T.tvm_stack_make_array for the packed call.
            T.call_packed("tvm.contrib.thrust.sort", T.tvm_stack_make_array(data_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0.0), T.int64(0)), T.tvm_stack_make_array(value_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0.0), T.int64(0)), T.tvm_stack_make_array(indices_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, 0, T.int64(0)), 0, T.tvm_stack_make_array(workspace_buf.data, T.tvm_stack_make_shape(T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12)), 0, 1, T.uint8(0), T.int64(0)))

    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Decode-step attention of one query row per batch entry against a paged KV cache.
        #
        # For batch entry b (blockIdx.x) and head h (blockIdx.y) this computes
        #     s[i]          = dot(Q[b, h, :], K_i) * attn_score_scaling_factor * sm_scale
        #     output[b, h]  = sum_i exp2(s[i]) * V_i / sum_i exp2(s[i])
        #     lse[b, h]     = log2(sum_i exp2(s[i]))
        # over the kv_chunk_len cached rows i of sequence b.
        # sm_scale = 0.14724444602590309 equals log2(e) / sqrt(96), so
        # exp2(s) == exp(dot * attn_score_scaling_factor / sqrt(96)). The result is the
        # usual 1/sqrt(head_dim)-scaled softmax evaluated in base 2, and lse is a
        # base-2 log-sum-exp.
        #
        # Buffers (shapes taken from the match_buffer calls below):
        #   Q, output         -- (B, 32, 96) float16: 32 heads, head dimension 96.
        #   pages             -- (max_num_pages, 2, 32, 16, 96) float16. Axis 1 selects K (0)
        #                        or V (1), axis 2 is the head, axis 3 the row inside a
        #                        16-row page.
        #   page_table_indptr -- (B + 1,) int32; sequence b owns page_table_values entries
        #                        [page_table_indptr[b], page_table_indptr[b + 1]).
        #   page_table_values -- (nnz_pages,) int32 page numbers.
        #   length_info       -- (B,) int32; rows used in the sequence's last page.
        #   k_rope_pos_offset -- (B,) int32; the RoPE position of cached row i is offset + i.
        #   q_rope_position   -- (B,) int32; RoPE position of the query row.
        #   lse               -- (B, 32), default (float32) dtype.
        # Scalars:
        #   rotary_mode == 1 applies RoPE to Q and K on the fly.
        #   rope_scale and rope_theta parameterize the RoPE frequency.
        #   _0 is not used in the body.
        #
        # Thread layout inside each (b, h) block:
        #   tx in [0, 24) owns head-dimension elements tx * 4 .. tx * 4 + 3 (24 * 4 = 96).
        #   tz in [0, 21) owns 2 KV rows of every 42-row tile.
        #   ty has extent 1.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 32, 96), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 32, 16, 96), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 32, 96), "float16")
        lse = T.match_buffer(lse_handle, (B, 32))
        # with T.block("root"):
        # log2(e) / sqrt(96): folds the 1/sqrt(head_dim) scale and the exp -> exp2 conversion together.
        sm_scale: T.float32 = T.float32(0.14724444602590309)
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(32, thread="blockIdx.y"):
                for ty in T.thread_binding(1, thread="threadIdx.y"):
                    for tx in T.thread_binding(24, thread="threadIdx.x"):
                        for tz in T.thread_binding(21, thread="threadIdx.z"):
                            with T.block("attn"):
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, fused_by_bz // 32 + ty + fused_by_bz % 32, tx * 4 - 48:tx * 4 - 48 + 100])
                                T.writes(output[bx, fused_by_bz % 32 + fused_by_bz // 32 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 32 + fused_by_bz // 32 + ty])
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((42, 96), "float16", scope="shared")
                                V_smem = T.alloc_buffer((42, 96), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((21, 1, 96), scope="shared")
                                md_allreduce = T.alloc_buffer((21, 1, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((2,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                # by is the head index. bz = fused_by_bz // 32 is always 0 because
                                # blockIdx.y has extent 32.
                                by: T.int32 = fused_by_bz % 32
                                bz: T.int32 = fused_by_bz // 32
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Number of cached rows: all but the last page are full (16 rows), plus
                                # length_info[b] rows of the last page; 0 if the sequence has no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                # Online-softmax state for this thread:
                                #   st_m    -- running max; -50000 acts as the "-inf" sentinel.
                                #   st_d    -- running denominator. It starts at 1.0, but that initial
                                #              value is rescaled by exp2(-50000 - max) and so becomes
                                #              negligible once any real score is seen.
                                #   O_local -- unnormalized accumulator for the 4 owned head-dim elements.
                                # NOTE(review): when kv_chunk_len == 0 the tile loop never runs. The output
                                # is then 0 and lse stays close to -50000.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements, applying RoPE when rotary_mode == 1.
                                # Element d is paired with -Q[d + 48] for d < 48 and with Q[d - 48]
                                # otherwise (rotate-half layout), using
                                #     freq = pos * rope_scale / rope_theta ** (((2 * d) % 96) / 96).
                                # Both the cos and the sin term are multiplied by the constant
                                # 1.1902380714238083.
                                # NOTE(review): that constant is baked in at build time. Its origin
                                # (possibly a long-context RoPE attention factor) cannot be determined
                                # from here -- confirm.
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 48, Q[bx, by + bz + ty, tx * 4 + vec + 48] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 48]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 96) / T.float32(96.0))}), Q[bx, by + bz + ty, tx * 4 + vec])
                                # Walk the cached rows in tiles of 42 (21 tz threads x 2 rows each).
                                for iterator in range((kv_chunk_len[0] + 41) // 42):
                                    tile_start_s: T.int32 = (tz + ty) * 2
                                    tile_start_g: T.int32 = (iterator * 21 + tz + ty) * 2
                                    # Stage this thread's K and V elements of its 2 rows into shared memory.
                                    # K gets RoPE at position k_rope_pos_offset[b] + row_g when
                                    # rotary_mode == 1. Rows at or past kv_chunk_len are zero-filled.
                                    for j in range(2):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Translate the row index to (page number, row within page).
                                                seq_offset: T.int32 = row_g
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 48, pages[page_no, 0, by, page_offset, tx * 4 + vec + 48] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 48]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 96) / T.float32(96.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # NOTE(review): each thread later reads back only the K_smem / V_smem
                                    # elements it wrote itself (rows tz * 2 + j, columns tx * 4 .. tx * 4 + 3,
                                    # since ty is always 0). This barrier therefore appears conservative.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score of each of this thread's 2 rows:
                                    #   1. a per-thread partial dot product over its 4 elements;
                                    #   2. a sum across the 24 tx threads by tvm_thread_allreduce into t0[0].
                                    # Rows at or past kv_chunk_len keep the -50000 sentinel score.
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 21 + tz) * 2 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Online-softmax update:
                                    #   1. rescale the old denominator and accumulator to the new max;
                                    #   2. add this tile's exp2 weights and the weighted V rows.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(2):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Cross-tz merge, step 1: publish each tz's partial (O_local, st_m, st_d)
                                # to shared memory and wait for all threads of the block.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Cross-tz merge, step 2: starting from the sentinel state, every thread
                                # folds in all 21 partials with the same max / rescale rule. As a result,
                                # all tz threads compute and store identical output and lse values.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(21):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize and store. lse is the base-2 log-sum-exp st_m + log2(st_d).
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Decode-step attention of one query row per batch entry against a paged KV cache,
        # where the attended rows are a remapped subset of each sequence's cached rows.
        #
        # For batch entry b (blockIdx.x) and head h (blockIdx.y) this computes
        #     s[i]          = dot(Q[b, h, :], K_i) * attn_score_scaling_factor * sm_scale
        #     output[b, h]  = sum_i exp2(s[i]) * V_i / sum_i exp2(s[i])
        #     lse[b, h]     = log2(sum_i exp2(s[i]))
        # over kv_chunk_len logical rows i.
        # sm_scale = 0.14724444602590309 equals log2(e) / sqrt(96). This makes the result
        # the usual 1/sqrt(head_dim)-scaled softmax evaluated in base 2.
        #
        # The only difference from batch_decode_paged_kv is length_info, which is (3, B):
        #   kv_chunk_len = (num_pages - 1) * 16
        #                  + length_info[0, b] - length_info[1, b] + length_info[2, b]
        #   logical row i is read from physical row
        #       i                                        if i < length_info[2, b]
        #       i - length_info[2, b] + length_info[1, b] otherwise
        # Presumably [0] = rows used in the last page, [1] = sliding-window start offset
        # and [2] = number of retained leading ("sink") rows -- confirm with the caller.
        #
        # Other buffers (shapes taken from the match_buffer calls below):
        #   Q, output         -- (B, 32, 96) float16: 32 heads, head dimension 96.
        #   pages             -- (max_num_pages, 2, 32, 16, 96) float16. Axis 1 selects K (0)
        #                        or V (1), axis 2 is the head, axis 3 the row inside a
        #                        16-row page.
        #   page_table_indptr -- (B + 1,) int32; sequence b owns page_table_values entries
        #                        [page_table_indptr[b], page_table_indptr[b + 1]).
        #   page_table_values -- (nnz_pages,) int32 page numbers.
        #   k_rope_pos_offset -- (B,) int32; the RoPE position of *logical* row i is offset + i.
        #   q_rope_position   -- (B,) int32; RoPE position of the query row.
        #   lse               -- (B, 32), default (float32) dtype.
        # Scalars:
        #   rotary_mode == 1 applies RoPE to Q and K on the fly.
        #   rope_scale and rope_theta parameterize the RoPE frequency.
        #   _0 is not used in the body.
        #
        # Thread layout inside each (b, h) block:
        #   tx in [0, 24) owns head-dimension elements tx * 4 .. tx * 4 + 3 (24 * 4 = 96).
        #   tz in [0, 21) owns 2 KV rows of every 42-row tile.
        #   ty has extent 1.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 32, 96), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 32, 16, 96), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 32, 96), "float16")
        lse = T.match_buffer(lse_handle, (B, 32))
        # with T.block("root"):
        # log2(e) / sqrt(96): folds the 1/sqrt(head_dim) scale and the exp -> exp2 conversion together.
        sm_scale: T.float32 = T.float32(0.14724444602590309)
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(32, thread="blockIdx.y"):
                for ty in T.thread_binding(1, thread="threadIdx.y"):
                    for tx in T.thread_binding(24, thread="threadIdx.x"):
                        for tz in T.thread_binding(21, thread="threadIdx.z"):
                            with T.block("attn"):
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, fused_by_bz // 32 + ty + fused_by_bz % 32, tx * 4 - 48:tx * 4 - 48 + 100])
                                T.writes(output[bx, fused_by_bz % 32 + fused_by_bz // 32 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 32 + fused_by_bz // 32 + ty])
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((42, 96), "float16", scope="shared")
                                V_smem = T.alloc_buffer((42, 96), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((21, 1, 96), scope="shared")
                                md_allreduce = T.alloc_buffer((21, 1, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((2,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                # by is the head index. bz = fused_by_bz // 32 is always 0 because
                                # blockIdx.y has extent 32.
                                by: T.int32 = fused_by_bz % 32
                                bz: T.int32 = fused_by_bz // 32
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Number of logical rows to attend to. It is derived from the page count and
                                # the three length_info entries; 0 if the sequence has no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                # Online-softmax state for this thread:
                                #   st_m    -- running max; -50000 acts as the "-inf" sentinel.
                                #   st_d    -- running denominator. It starts at 1.0, but that initial
                                #              value is rescaled by exp2(-50000 - max) and so becomes
                                #              negligible once any real score is seen.
                                #   O_local -- unnormalized accumulator for the 4 owned head-dim elements.
                                # NOTE(review): when kv_chunk_len == 0 the tile loop never runs. The output
                                # is then 0 and lse stays close to -50000.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements, applying RoPE when rotary_mode == 1.
                                # Element d is paired with -Q[d + 48] for d < 48 and with Q[d - 48]
                                # otherwise (rotate-half layout), using
                                #     freq = pos * rope_scale / rope_theta ** (((2 * d) % 96) / 96).
                                # Both the cos and the sin term are multiplied by the constant
                                # 1.1902380714238083.
                                # NOTE(review): that constant is baked in at build time. Its origin
                                # (possibly a long-context RoPE attention factor) cannot be determined
                                # from here -- confirm.
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 48, Q[bx, by + bz + ty, tx * 4 + vec + 48] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 48]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 96) / T.float32(96.0))}), Q[bx, by + bz + ty, tx * 4 + vec])
                                # Walk the logical rows in tiles of 42 (21 tz threads x 2 rows each).
                                for iterator in range((kv_chunk_len[0] + 41) // 42):
                                    tile_start_s: T.int32 = (tz + ty) * 2
                                    tile_start_g: T.int32 = (iterator * 21 + tz + ty) * 2
                                    # Stage this thread's K and V elements of its 2 rows into shared memory.
                                    # Rows at or past kv_chunk_len are zero-filled.
                                    for j in range(2):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Map logical row row_g to a physical row:
                                                #   - the first length_info[2, b] rows map to themselves;
                                                #   - later rows are shifted by length_info[1, b] - length_info[2, b].
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                # K's RoPE position uses the logical row (k_rope_pos_offset[b] + row_g),
                                                # not the physical seq_offset.
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 48, pages[page_no, 0, by, page_offset, tx * 4 + vec + 48] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 48]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 96) / T.float32(96.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # NOTE(review): each thread later reads back only the K_smem / V_smem
                                    # elements it wrote itself (rows tz * 2 + j, columns tx * 4 .. tx * 4 + 3,
                                    # since ty is always 0). This barrier therefore appears conservative.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score of each of this thread's 2 rows:
                                    #   1. a per-thread partial dot product over its 4 elements;
                                    #   2. a sum across the 24 tx threads by tvm_thread_allreduce into t0[0].
                                    # Rows at or past kv_chunk_len keep the -50000 sentinel score.
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 21 + tz) * 2 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Online-softmax update:
                                    #   1. rescale the old denominator and accumulator to the new max;
                                    #   2. add this tile's exp2 weights and the weighted V rows.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(2):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Cross-tz merge, step 1: publish each tz's partial (O_local, st_m, st_d)
                                # to shared memory and wait for all threads of the block.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Cross-tz merge, step 2: starting from the sentinel state, every thread
                                # folds in all 21 partials with the same max / rescale rule. As a result,
                                # all tz threads compute and store identical output and lse values.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(21):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize and store. lse is the base-2 log-sum-exp st_m + log2(st_d).
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0])

    # Batched prefill attention over a paged KV cache (generated, scheduled TIR).
    #
    # Launch shape (fixed): blockIdx.x in [0, 16) walks query tiles, blockIdx.y in
    # [0, 32) selects the head (`by` indexes the head axis of q, pages, output and
    # lse), and each block runs 4 x 32 threads (threadIdx.y, threadIdx.x).
    # A query tile is 32 rows of one sequence; the head dim is 96; KV tokens are
    # consumed in chunks of 32 from pages of 16 tokens each.
    #
    # Parameters:
    #   _0: not referenced in this body.
    #   q: (total_len, 32, 96) float16 queries. q_indptr (batch_size + 1,) gives
    #      each sequence's [start, end) row range in q / output / lse.
    #   pages: (max_num_pages, 2, 32, 16, 96) float16; index 0 of axis 1 is read
    #      as K and index 1 as V.
    #   page_indptr / page_values: sequence b uses page ids
    #      page_values[page_indptr[b]:page_indptr[b + 1]], in order.
    #   length_info: (batch_size,) int32 added to the full-page token count;
    #      presumably the number of valid tokens in the last page (TODO confirm).
    #   k_rope_pos_offset / q_rope_position: RoPE positions for K (offset added to
    #      the token index) and Q (per query row).
    #   output: (total_len, 32, 96) float16 attention output.
    #   lse: (total_len, 32) float32 log-sum-exp of the scaled scores, base 2.
    #   causal: > 0 enables the causal mask (queries aligned to the KV end).
    #   rotary_mode: == 1 applies RoPE to Q and K while loading them.
    #   rope_scale, rope_theta: RoPE position scale and frequency base.
    #   attn_score_scaling_factor: extra multiplier applied to every Q.K score.
    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 32, 96), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 32, 16, 96), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 32, 96), "float16")
        lse = T.match_buffer(var_lse, (total_len, 32))
        # with T.block("root"):
        # Fixed grid of 16 x 32 blocks; query tiles are distributed over
        # blockIdx.x by the persistent while-loop below.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(32, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Per-thread scalars ("local" scope) and per-block tiles
                            # ("shared" scope).
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # NOTE(review): `iterator` is allocated but never used in this body.
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 96), scope="local")
                            # Online-softmax state per query row: running max, max before
                            # the current chunk, and running denominator.
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start at tile bx of sequence 0; batch_tiles = ceil(rows / 32).
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            # Persistent loop: this block handles global tiles bx, bx + 16,
                            # bx + 32, ... (tile_id advances by 16 at the bottom).
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip whole sequences until tile_id lies inside batch_idx's tiles.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    # First query row of this tile, relative to the sequence start.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    # KV token count: (pages - 1) full pages of 16 plus
                                    # length_info[b_idx]; 0 when the sequence has no pages.
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    # Reset softmax state: max = -50000 (stand-in for -inf), denom = 1.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator; each thread owns a 4 x 6 sub-tile.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(3):
                                                    for lj_1_1 in T.vectorized(2):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 96 query tile into shared memory. When
                                    # rotary_mode == 1, RoPE pairs element j with j + 48 / j - 48
                                    # (the 1.1902380714238083 factor is emitted by the generator;
                                    # its meaning is not visible here). Rows past the sequence end
                                    # are zero-filled.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 3):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(96, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, q[cur_L, cur_H_qo, j + 48] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 48]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Walk the KV sequence in chunks of 32 tokens.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        # Load the K chunk (RoPE at position k_rope_pos_offset[b_idx] +
                                        # token index when rotary_mode == 1); tokens at or beyond
                                        # kv_chunk_len are zero-filled.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 3):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(96, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                # Token index -> (page id, slot within the 16-token page).
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, pages[page_no, 0, by, page_offset, j + 48] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 48]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Load the V chunk the same way (pages axis-1 index 1), no RoPE.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 3):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(96, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q @ K^T * attn_score_scaling_factor * 0.14724444602590309.
                                        # The constant equals log2(e) / sqrt(96), so the exp2 calls
                                        # below are equivalent to exp of the usual 1/sqrt(d) scaling.
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:96], K_smem[0:32, 0:96])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(6):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(96, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.14724444602590309)
                                        T.tvm_storage_sync("shared")
                                        # Publish the score tile to shared memory so one thread per
                                        # query row can run the softmax update below.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Online softmax, step 1 (threads with row < 32): new running
                                        # max over the unmasked scores of this chunk.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        # Causal: query row row_ sees KV positions
                                                        # < kv_chunk_len - q_len + row_ + 1 (queries aligned to
                                                        # the end of the KV sequence). Otherwise: < kv_chunk_len.
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    # Rescale the old denominator to the new max.
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Step 2: S <- exp2(S - m_new). Masked entries are scored as
                                        # -50000 (stand-in for -inf), not set to exactly 0.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Step 3: add the row sum to the denominator and publish the
                                        # new max, denominator and previous max for the O update.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O <- O * exp2(m_prev - m_new) + S @ V: rescale the accumulator
                                        # to the new max, then add this chunk's contribution.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:96])
                                            T.writes(O_local[0:32, 0:96])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(3):
                                                            for lj_1_1_init in T.vectorized(2):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(96, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 6 + lj_1_0_init * 2 + lj_1_1_init)
                                                                    # NOTE(review): T.reads() is empty although this block
                                                                    # reads O_local, m_prev_smem and m_smem.
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(3):
                                                                for lj_1_1 in T.vectorized(2):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    # Normalize by the final denominator and write rows inside the
                                    # sequence to output as float16.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(3):
                                                    for lj_1_1 in T.vectorized(2):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                            cur_H_qo: T.int32 = by
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Base-2 log-sum-exp per query row: lse = m + log2(d). T.where
                                    # restricts this to the first 32 threads (threadIdx.y == 0).
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to this block's next tile (16 blocks along blockIdx.x).
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 32, 96), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 32, 16, 96), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 32, 96), "float16")
        lse = T.match_buffer(var_lse, (total_len, 32))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(32, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 96), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(3):
                                                    for lj_1_1 in T.vectorized(2):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 3):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(96, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, q[cur_L, cur_H_qo, j + 48] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 48]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 3):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(96, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, pages[page_no, 0, by, page_offset, j + 48] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 48]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 3):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(96, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:96], K_smem[0:32, 0:96])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(6):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(96, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.14724444602590309)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:96])
                                            T.writes(O_local[0:32, 0:96])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(3):
                                                            for lj_1_1_init in T.vectorized(2):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(96, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 6 + lj_1_0_init * 2 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(3):
                                                                for lj_1_1 in T.vectorized(2):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(3):
                                                    for lj_1_1 in T.vectorized(2):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                            cur_H_qo: T.int32 = by
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 32, 96), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 32, 96), "float16")
        v = T.match_buffer(var_v, (kv_len, 32, 96), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 32, 96), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 32))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(32, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            K_smem = T.alloc_buffer((96, 32), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 96), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(3):
                                                    for lj_1_1 in T.vectorized(2):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 3):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(96, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, q[cur_L, cur_H_qo, j + 48] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 48]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 3):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(96, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 48:j - 48 + 97])
                                                            T.writes(K_smem[j, i])
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, k[L_kv_base + cur_L, by, j + 48] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 48]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), k[L_kv_base + cur_L, by, j])
                                                            else:
                                                                K_smem[j, i] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 3):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(96, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                V_smem[i, j] = v[L_kv_base + cur_L, by, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:96], K_smem[0:96, 0:32])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(6):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k_1 = T.axis.reduce(96, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.14724444602590309)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:96])
                                            T.writes(O_local[0:32, 0:96])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(3):
                                                            for lj_1_1_init in T.vectorized(2):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(96, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 6 + lj_1_0_init * 2 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(3):
                                                                for lj_1_1 in T.vectorized(2):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                                        k_1 = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(3):
                                                    for lj_1_1 in T.vectorized(2):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1_0 * 2 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                            cur_H_qo: T.int32 = by
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 32, 96), "float16")
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 32, 96), "float16")
        v = T.match_buffer(var_v, (kv_len, 32, 96), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1)
        tree_size = T.int32(is_size_var=True)
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 32, 96), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 32))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(32, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 96), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 6):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(6):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 96)
                                                        j = T.axis.spatial(96, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 96)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, q[cur_L, cur_H_qo, j + 48] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 48]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_ly_fused_0 in range(6):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("KV_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 96)
                                                            j = T.axis.spatial(96, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 96)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_base + L_kv_start + i
                                                            if L_kv_start + i < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, k[cur_L, by, j + 48] * T.float16(-1.0), k[cur_L, by, j - 48]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), k[cur_L, by, j])
                                                                V_smem[i, j] = v[cur_L, by, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:96], K_smem[0:32, 0:96])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(12, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k_1 = T.axis.reduce(96, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.14724444602590309)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:96])
                                            T.writes(O_local[0:32, 0:96])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 6):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 6 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 6):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1)
                                                            k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 6):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func(private=True)
    def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle):
        # Walk a token tree once per batch entry, accepting or rejecting draft tokens.
        #
        # Launch: one CUDA thread block per batch entry `b` (blockIdx.x over
        # `nbatch`), 1024 threads per block (threadIdx.x). Starting at node
        # parent = token_tree_parent_ptr[b], child = token_tree_first_child[parent],
        # the block repeatedly tests the current child (tok = draft_tokens[child]):
        #   * accept if model_probs[parent, tok] >= uniform_samples[child] * draft_probs[child, tok];
        #     then descend: parent <- child, child <- token_tree_first_child[child];
        #   * otherwise compute t0 = sum_k max(model_probs[parent, k] - draft_probs[child, k], 0).
        #     If t0 < 1e-7 the child is descended into exactly as on acceptance.
        #     Else model_probs[parent, :] is overwritten with that clipped
        #     difference divided by t0, and the walk moves to
        #     token_tree_next_sibling[child].
        # The walk stops when child == -1. Thread 0 then stores the final parent
        # node into token_tree_parent_ptr[b].
        #
        # Side effects: mutates rows of `model_probs` in place and writes
        # `token_tree_parent_ptr`.
        # NOTE(review): this looks like the speculative-sampling acceptance /
        # residual-distribution rule applied over a draft token tree. Confirm
        # against the caller.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_nodes, vocab_size = T.int32(is_size_var=True), T.int64(is_size_var=True)
        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size))
        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32")
        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size))
        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32")
        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32")
        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,))
        nbatch = T.int32(is_size_var=True)
        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32")
        # with T.block("root"):
        # Per-thread scalar scratch (scope="local").
        child_ptr = T.alloc_buffer((1,), "int32", scope="local")
        parent_ptr = T.alloc_buffer((1,), "int32", scope="local")
        child_token = T.alloc_buffer((1,), "int32", scope="local")
        done = T.alloc_buffer((1,), "bool", scope="local")
        psum = T.alloc_buffer((1,), scope="local")
        t0 = T.alloc_buffer((1,), scope="local")
        model_prob_local = T.alloc_buffer((1,), scope="local")
        draft_prob_local = T.alloc_buffer((1,), scope="local")
        p_child = T.alloc_buffer((1,), scope="local")
        q_child = T.alloc_buffer((1,), scope="local")
        uniform_sample = T.alloc_buffer((1,), scope="local")
        # The accept/reject decision is computed by thread 0 and broadcast
        # through this shared-memory flag.
        pred_shared = T.alloc_buffer((1,), "bool", scope="shared")
        pred_local = T.alloc_buffer((1,), "bool", scope="local")
        for _bx in T.thread_binding(nbatch, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    # The access regions below look machine-generated and
                    # over-approximate the actual accesses. They are not hand-edited.
                    T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]])
                    T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b])
                    parent_ptr[0] = token_tree_parent_ptr[b]
                    child_ptr[0] = token_tree_first_child[parent_ptr[0]]
                    done[0] = T.bool(False)
                    while not done[0]:
                        # Orders the previous iteration's writes (pred_shared and any
                        # in-place model_probs update) before this iteration's reads.
                        T.tvm_storage_sync("shared")
                        if child_ptr[0] == -1:
                            # No candidate child left: stop the walk.
                            done[0] = T.bool(True)
                            T.tvm_storage_sync("shared")
                        else:
                            if tx == 0:
                                # Thread 0 alone evaluates the acceptance test for this child.
                                child_token[0] = draft_tokens[child_ptr[0]]
                                p_child[0] = model_probs[parent_ptr[0], child_token[0]]
                                q_child[0] = draft_probs[child_ptr[0], child_token[0]]
                                uniform_sample[0] = uniform_samples[child_ptr[0]]
                                pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
                            T.tvm_storage_sync("shared")
                            # Every thread reads the same shared decision, so the whole
                            # block takes the same branch and stays in step at the syncs.
                            pred_local[0] = pred_shared[0]
                            if pred_local[0]:
                                # Accepted: descend into the child.
                                parent_ptr[0] = child_ptr[0]
                                child_ptr[0] = token_tree_first_child[child_ptr[0]]
                            else:
                                # Rejected. Each thread accumulates max(model - draft, 0)
                                # over vocab indices tx, tx + 1024, tx + 2048, ...
                                psum[0] = T.float32(0.0)
                                for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                    if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                        model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                        psum[0] = psum[0] + model_prob_local[0]
                                # Block-wide sum of the per-thread partials into t0.
                                # NOTE(review): assumes tvm_thread_allreduce leaves the same total
                                # in every thread, which keeps the branch below uniform.
                                with T.block("block_cross_thread"):
                                    T.reads(psum[0])
                                    T.writes(t0[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx)
                                if t0[0] < T.float32(9.9999999999999995e-08):
                                    # Residual mass is ~0: handled the same way as acceptance (descend).
                                    parent_ptr[0] = child_ptr[0]
                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]
                                else:
                                    # Overwrite the parent's model_probs row with the clipped
                                    # difference normalized by t0, then try the next sibling.
                                    for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                        if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                            model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                            model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0]
                                    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
                    if tx == 0:
                        # Publish the node where the walk ended for this batch entry.
                        token_tree_parent_ptr[b] = parent_ptr[0]

    @T.prim_func
    def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle):
        # Per-chunk max and (log-)sum-exp of temperature-scaled values.
        #
        # A is (batch_size, vocab_size). Its last axis is split into
        # `num_chunks` chunks of 4096 entries. With x = A / temperature when
        # temperature > 1e-5 (raw A otherwise), for each (row, chunk):
        #   chunked_max = max(x over the chunk)
        #   chunked_sum = log(sum(exp(x - chunked_max)))  if temperature > 1e-5
        #               = number of x equal to chunked_max (as float32)  otherwise
        # Entries at or beyond vocab_size are padded with -FLT_MAX and excluded
        # from the sum. For temperature > 1e-5, chunked_max + chunked_sum is the
        # chunk's log-sum-exp.
        # NOTE(review): num_chunks is presumably ceil(vocab_size / 4096). It is
        # not checked here, and any entries at index >= num_chunks * 4096 would
        # be ignored.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        # with T.block("root"):
        A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)))
        temp_max = T.alloc_buffer((batch_size, num_chunks))
        temp_sum = T.alloc_buffer((batch_size, num_chunks))
        # Reshape to (row, chunk, 4096). Out-of-range positions become -FLT_MAX;
        # in-range values are divided by temperature only when temperature > 1e-5.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("pad"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2])
                T.writes(A_pad[v0, v1, v2])
                A_pad[v0, v1, v2] = T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0))
        # Per-chunk maximum.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("max"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(A_pad[v0, v1, v2])
                T.writes(temp_max[v0, v1])
                with T.init():
                    temp_max[v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
                temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
        # Per-chunk sum of exp(x - max). When temperature <= 1e-5 it instead
        # counts entries tied with the max. Padding contributes 0.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("sum_exp"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(temperature[v0], A_pad[v0, v1, v2], temp_max[v0, v1])
                T.writes(temp_sum[v0, v1])
                with T.init():
                    temp_sum[v0, v1] = T.float32(0.0)
                temp_sum[v0, v1] = temp_sum[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.Cast("float32", A_pad[v0, v1, v2] == temp_max[v0, v1])), T.float32(0.0))
        # Outputs: log of the sum (or the raw tie count at low temperature), and the max.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
            with T.block("log"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], temp_sum[v0, v1], temp_max[v0, v1])
                T.writes(chunked_sum[v0, v1], chunked_max[v0, v1])
                chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum[v0, v1]), temp_sum[v0, v1])
                chunked_max[v0, v1] = temp_max[v0, v1]

    @T.prim_func
    def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32):
        # Copy cache entries between token positions inside the paged buffer.
        #
        # For sequence b, columns copy_length_indptr[b] .. copy_length_indptr[b + 1] - 1
        # of copy_src_dst_pos hold (row 0 = src_pos, row 1 = dst_pos) pairs.
        # Token position p lives at pages[p // 16, :, :, p % 16, :]. Both entries
        # of pages' second axis (index 0 and 1; presumably K and V, to be
        # confirmed) are copied.
        #
        # Each thread owns one (b, head h, dim d) element and performs that
        # sequence's copies serially in order. Copy order within a sequence is
        # therefore preserved per element, but nothing orders copies of
        # different sequences against each other.
        # NOTE(review): 32 heads, 16 slots per page and head dim 96 are
        # hard-coded below and must match the `pages` shape.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 32, 16, 96), "float16", offset_factor=1)
        copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1)
        total_copy_length = T.int32()
        copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes()
            # batch_size * 3 blocks of 1024 threads = batch_size * 3072 = batch_size * 32 * 96 threads.
            for bhd_o in T.thread_binding(batch_size * 3, thread="blockIdx.x"):
                for bhd_i in T.thread_binding(1024, thread="threadIdx.x"):
                    # Decompose the flat thread id into (sequence, head, dim).
                    b: T.int32 = (bhd_o * 1024 + bhd_i) // 3072
                    h: T.int32 = (bhd_o * 1024 + bhd_i) // 96 % 32
                    d: T.int32 = (bhd_o * 1024 + bhd_i) % 96
                    # Always true with the exact grid size above; acts as a bounds guard.
                    if bhd_o * 1024 + bhd_i < batch_size * 32 * 96:
                        for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d]

    @T.prim_func
    def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64):
        # Copy the first `copy_length` slots of page `src_page_id` into page
        # `tgt_page_id`, for both entries of pages' second axis, all 32 heads
        # and all 96 dims.
        #
        # Launch: copy_length * 3 blocks of 1024 threads
        # (= copy_length * 32 * 96 threads). Each thread copies one
        # (head, slot, dim) element, with head as the outermost index.
        # NOTE(review): assumes copy_length <= page_size. The slot index ranges
        # over [0, copy_length) and is not checked against page_size here.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        num_pages, page_size = T.int32(), T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 32, page_size, 96), "float16", offset_factor=1)
        # with T.block("root"):
        for b in T.thread_binding(copy_length * T.int64(3), thread="blockIdx.x"):
            for t in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy"):
                    # Flat id = b * 1024 + t, decomposed as (head, slot, dim).
                    vh = T.axis.spatial(32, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(96))))
                    vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(96)) // T.int64(96))
                    vd = T.axis.spatial(96, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(96)))
                    # Bounds guard. The grid above yields exactly copy_length * 32 * 96 threads.
                    T.where(b * T.int64(1024) + T.Cast("int64", t) < copy_length * T.int64(32) * T.int64(96))
                    T.reads(pages[src_page_id, 0:2, vh, vp, vd])
                    T.writes(pages[tgt_page_id, 0:2, vh, vp, vd])
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    @T.prim_func(private=True)
    def cumsum(var_sorted_probs: T.handle, var_lv1: T.handle, var_exclusive_scan_thrust: T.handle):
        # Prefix sum of a (batch_size, vocab_size) float32 tensor, delegated to Thrust.
        #
        # Calls the packed function "tvm.contrib.thrust.sum_scan" with the input
        # array, the output array, a bool flag (False here) and a uint8
        # workspace buffer whose byte size is given by the expression below.
        # NOTE(review): the flag is presumably `exclusive`. That would make this an
        # inclusive scan despite the "exclusive_scan_thrust" naming, presumably
        # along the last axis. Confirm against the tvm.contrib.thrust implementation.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        data_buf = T.match_buffer(var_sorted_probs, (batch_size, vocab_size), align=8)
        workspace_buf = T.match_buffer(var_lv1, (T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12),), "uint8", align=8)
        output_buf = T.match_buffer(var_exclusive_scan_thrust, (batch_size, vocab_size), align=8)
        with T.block("exclusive_scan_thrust"):
            T.reads()
            T.writes()
            # Arguments are wrapped as DLTensor-style arrays: (data, shape, strides=0, ndim, dtype hint, elem_offset).
            T.call_packed("tvm.contrib.thrust.sum_scan", T.tvm_stack_make_array(data_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0.0), T.int64(0)), T.tvm_stack_make_array(output_buf.data, T.tvm_stack_make_shape(batch_size, vocab_size), 0, 2, T.float32(0.0), T.int64(0)), T.bool(False), T.tvm_stack_make_array(workspace_buf.data, T.tvm_stack_make_shape(T.int64(8) * (batch_size * vocab_size * T.int64(4)) + T.int64(8388608) + batch_size * vocab_size * T.int64(12)), 0, 1, T.uint8(0), T.int64(0)))

    @T.prim_func
    def full(var_result: T.handle, value: T.int32):
        # Fill the single column of an int32 (batch_size, 1) buffer with `value`.
        # Written as a plain serial loop: there are no thread bindings and no
        # "tir.is_scheduled" attribute here.
        # NOTE(review): presumably a later scheduling pass maps this onto GPU
        # threads. Confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch_size = T.int32(is_size_var=True)
        result = T.match_buffer(var_result, (batch_size, 1), "int32")
        # with T.block("root"):
        for i in range(batch_size):
            with T.block("block"):
                vi = T.axis.spatial(batch_size, i)
                T.reads()
                T.writes(result[vi, 0])
                result[vi, 0] = value

    @T.prim_func(private=True)
    def fuse_add_norm_decode(pA: T.handle, pB: T.handle, C: T.Buffer((3072,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused residual add + RMS normalization for (batch_size, 1, 3072) float16 inputs.
        #
        #   add = A + B                                   (float16, stored to `add`)
        #   O   = float16(rsqrt(sum(add^2) / 3072 + 1e-5) * add * C)
        # The sum of squares and the scaling are computed in float32.
        #
        # Launch: one thread block per batch row (blockIdx.x) with 1024 threads.
        # Thread tx handles hidden indices tx, tx + 1024 and tx + 2048
        # (3 * 1024 = 3072).
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        A = T.match_buffer(pA, (batch_size, 1, 3072), "float16")
        B = T.match_buffer(pB, (batch_size, 1, 3072), "float16")
        O = T.match_buffer(pO, (batch_size, 1, 3072), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, 3072), "float16")
        # with T.block("root"):
        # This thread's three A + B values; index h // 1024 == loop index i.
        add_local = T.alloc_buffer((3,), "float16", scope="local")
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                for i in range(3):
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(3072, i * 1024 + v_tx)
                        T.reads(A[bx, 0, h], B[bx, 0, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
                    # Also emit the residual sum itself.
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(3072, i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[bx, v_ax1, h])
                        add[bx, v_ax1, h] = add_local[h // 1024]
                # Per-thread partial sum of squares, in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, bx, 0])
                    sum_local[tx, bx, 0] = T.float32(0.0)
                for v_i, _j in T.grid(3, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, bx, 0], add_local[i])
                        T.writes(sum_local[tx, bx, 0])
                        sum_local[tx, bx, 0] = sum_local[tx, bx, 0] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Reduction of the 1024 per-thread partials across threadIdx.x into sum_shared.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0.0)
                        sum_shared[bx, 0] = sum_shared[bx, 0] + sum_local[tx, bx, 0]
            # Normalize and scale. add_local is thread-local, so correctness
            # requires v_tx_2 here to be the same physical threadIdx.x as v_tx above.
            # 0.00032552083333333332 == 1 / 3072.
            for i in range(3):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(3072, i * 1024 + v_tx_2)
                        T.reads(sum_shared[bx, 0], add_local[h // 1024], C[h])
                        T.writes(O[bx, 0, h])
                        O[bx, 0, h] = T.Cast("float16", T.rsqrt(sum_shared[bx, 0] * T.float32(0.00032552083333333332) + T.float32(1.0000000000000001e-05)) * T.Cast("float32", add_local[h // 1024]) * T.Cast("float32", C[h]))

    @T.prim_func(private=True)
    def fuse_add_norm_prefill(pA: T.handle, pB: T.handle, C: T.Buffer((3072,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused residual add + RMS normalization for (1, seq_len, 3072) float16 inputs.
        #
        #   add = A + B                                   (float16, stored to `add`)
        #   O   = float16(rsqrt(sum(add^2) / 3072 + 1e-5) * add * C)
        # The sum of squares and the scaling are computed in float32.
        #
        # Launch: one thread block per token (blockIdx.x over seq_len) with 1024
        # threads. Thread tx handles hidden indices tx, tx + 1024 and tx + 2048.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        A = T.match_buffer(pA, (1, seq_len, 3072), "float16")
        B = T.match_buffer(pB, (1, seq_len, 3072), "float16")
        O = T.match_buffer(pO, (1, seq_len, 3072), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, 3072), "float16")
        # with T.block("root"):
        # This thread's three A + B values; index h // 1024 == loop index v_i.
        add_local = T.alloc_buffer((3,), "float16", scope="local")
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                for v_i in range(3):
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(3072, v_i * 1024 + v_tx)
                        T.reads(A[0, bx, h], B[0, bx, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
                    # Also emit the residual sum itself.
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(3072, v_i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[0, bx, h])
                        add[0, bx, h] = add_local[h // 1024]
                # Per-thread partial sum of squares, in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, 0, bx])
                    sum_local[tx, 0, bx] = T.float32(0.0)
                for v_i, _j in T.grid(3, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, 0, bx], add_local[i])
                        T.writes(sum_local[tx, 0, bx])
                        sum_local[tx, 0, bx] = sum_local[tx, 0, bx] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Reduction of the 1024 per-thread partials across threadIdx.x into sum_shared.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, 0, bx])
                        T.writes(sum_shared[0, bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0.0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Normalize and scale. add_local is thread-local, so correctness
            # requires v_tx_2 here to be the same physical threadIdx.x as v_tx above.
            # 0.00032552083333333332 == 1 / 3072.
            for v_i in range(3):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(3072, v_i * 1024 + v_tx_2)
                        T.reads(sum_shared[0, bx], add_local[v1 // 1024], C[v1])
                        T.writes(O[0, bx, v1])
                        O[0, bx, v1] = T.Cast("float16", T.rsqrt(sum_shared[0, bx] * T.float32(0.00032552083333333332) + T.float32(1.0000000000000001e-05)) * T.Cast("float32", add_local[v1 // 1024]) * T.Cast("float32", C[v1]))

    @T.prim_func(private=True)
    def fused_dequantize1_NT_matmul(transformer_h_0_mixer_qkv_proj_q_weight4: T.Buffer((T.int64(9216), T.int64(384)), "uint32"), transformer_h_0_mixer_qkv_proj_q_scale4: T.Buffer((T.int64(9216), T.int64(96)), "float16"), p_rms_norm195: T.handle, p_output0: T.handle):
        # Dequantize a 4-bit packed weight W (9216 x 3072) and compute x @ W^T.
        #
        # Packing: each uint32 of the (9216, 384) weight holds eight 4-bit values
        # (384 * 8 = 3072 columns). Column k is stored in word k // 8 at bits
        # (k % 8) * 4 .. (k % 8) * 4 + 3. The (9216, 96) float16 scale has one
        # entry per group of 32 columns (96 * 32 = 3072).
        #   W[n, k] = (nibble(n, k) - 7) * scale[n, k // 32]
        #   NT_matmul_intermediate[b, 0, n] = sum_k rms_norm195[b, 0, k] * W[n, k]
        # The sum is accumulated in float16, as written (init T.float16(0.0)).
        # NOTE(review): there is no target or "tir.is_scheduled" attribute, so
        # this is presumably an unscheduled definition that a later pass schedules.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm195 = T.match_buffer(p_rms_norm195, (batch_size, T.int64(1), T.int64(3072)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(9216)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(9216), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(9216), T.int64(3072)), "float16")
        # Unpack: 4-bit field -> float16 value in [0, 15].
        for i0, i1 in T.grid(T.int64(9216), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mixer_qkv_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mixer_qkv_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Re-center by 7 and apply the per-32-column group scale.
        for i0, i1 in T.grid(T.int64(9216), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mixer_qkv_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mixer_qkv_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Matmul against the transposed dequantized weight (reduce over k = 3072).
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(9216), T.int64(3072)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm195[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm195[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul with fully static shapes (single row: the
    # activation is (1, 1, 3072)). Unpacks a 4-bit group-quantized weight of
    # logical shape (9216, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mixer_qkv_proj_q_weight2: (9216, 384) uint32; each word
    #       packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mixer_qkv_proj_q_scale2: (9216, 96) float16; one scale
    #       per group of 32 consecutive codes (3072 / 32 = 96).
    #   rms_norm65: (1, 1, 3072) float16 activation.
    #   NT_matmul_intermediate: (1, 1, 9216) float16 output,
    #       out[0, 0, n] = sum_k rms_norm65[0, 0, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize1_NT_matmul10(transformer_h_0_mixer_qkv_proj_q_weight2: T.Buffer((T.int64(9216), T.int64(384)), "uint32"), transformer_h_0_mixer_qkv_proj_q_scale2: T.Buffer((T.int64(9216), T.int64(96)), "float16"), rms_norm65: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(9216)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(9216), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(9216), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(9216), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mixer_qkv_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mixer_qkv_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(9216), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mixer_qkv_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mixer_qkv_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(9216), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm65[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm65[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul, dynamic-sequence variant (activation shape
    # (1, seq_len, 3072)). Unpacks a 4-bit group-quantized weight of logical
    # shape (9216, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mixer_qkv_proj_q_weight3: (9216, 384) uint32; each word
    #       packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mixer_qkv_proj_q_scale3: (9216, 96) float16; one scale
    #       per group of 32 consecutive codes (3072 / 32 = 96).
    #   p_rms_norm130: handle matched to a (1, seq_len, 3072) float16 input.
    #   p_output0: handle matched to the (1, seq_len, 9216) float16 output,
    #       out[0, s, n] = sum_k rms_norm130[0, s, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize1_NT_matmul5(transformer_h_0_mixer_qkv_proj_q_weight3: T.Buffer((T.int64(9216), T.int64(384)), "uint32"), transformer_h_0_mixer_qkv_proj_q_scale3: T.Buffer((T.int64(9216), T.int64(96)), "float16"), p_rms_norm130: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic sequence extent, bound by the match_buffer shapes below.
        seq_len = T.int64()
        rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(3072)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(9216)), "float16")
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(9216), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(9216), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(9216), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mixer_qkv_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mixer_qkv_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(9216), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mixer_qkv_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mixer_qkv_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(9216), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul, dynamic-batch variant (activation shape
    # (batch_size, 1, 3072)). Unpacks a 4-bit group-quantized weight of logical
    # shape (3072, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mixer_out_proj_q_weight4: (3072, 384) uint32; each word
    #       packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mixer_out_proj_q_scale4: (3072, 96) float16; one scale
    #       per group of 32 consecutive codes (3072 / 32 = 96).
    #   p_reshape387: handle matched to a (batch_size, 1, 3072) float16 input.
    #   p_output0: handle matched to the (batch_size, 1, 3072) float16 output,
    #       out[b, i, n] = sum_k reshape387[b, i, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul1(transformer_h_0_mixer_out_proj_q_weight4: T.Buffer((T.int64(3072), T.int64(384)), "uint32"), transformer_h_0_mixer_out_proj_q_scale4: T.Buffer((T.int64(3072), T.int64(96)), "float16"), p_reshape387: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic batch extent, bound by the match_buffer shapes below.
        batch_size = T.int64()
        reshape387 = T.match_buffer(p_reshape387, (batch_size, T.int64(1), T.int64(3072)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(3072)), "float16")
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(3072), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(3072), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(3072), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mixer_out_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mixer_out_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(3072), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mixer_out_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mixer_out_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(3072), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape387[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape387[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul with fully static shapes (single row: the
    # activation is (1, 1, 3072)). Unpacks a 4-bit group-quantized weight of
    # logical shape (3072, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mixer_out_proj_q_weight2: (3072, 384) uint32; each word
    #       packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mixer_out_proj_q_scale2: (3072, 96) float16; one scale
    #       per group of 32 consecutive codes (3072 / 32 = 96).
    #   lv100: (1, 1, 3072) float16 activation.
    #   NT_matmul_intermediate: (1, 1, 3072) float16 output,
    #       out[0, 0, n] = sum_k lv100[0, 0, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul11(transformer_h_0_mixer_out_proj_q_weight2: T.Buffer((T.int64(3072), T.int64(384)), "uint32"), transformer_h_0_mixer_out_proj_q_scale2: T.Buffer((T.int64(3072), T.int64(96)), "float16"), lv100: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(3072), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(3072), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(3072), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mixer_out_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mixer_out_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(3072), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mixer_out_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mixer_out_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(3072), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv100[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv100[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul, dynamic-sequence variant (activation shape
    # (1, seq_len, 3072)). Unpacks a 4-bit group-quantized weight of logical
    # shape (3072, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mixer_out_proj_q_weight3: (3072, 384) uint32; each word
    #       packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mixer_out_proj_q_scale3: (3072, 96) float16; one scale
    #       per group of 32 consecutive codes (3072 / 32 = 96).
    #   p_reshape259: handle matched to a (1, seq_len, 3072) float16 input.
    #   p_output0: handle matched to the (1, seq_len, 3072) float16 output,
    #       out[0, s, n] = sum_k reshape259[0, s, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul6(transformer_h_0_mixer_out_proj_q_weight3: T.Buffer((T.int64(3072), T.int64(384)), "uint32"), transformer_h_0_mixer_out_proj_q_scale3: T.Buffer((T.int64(3072), T.int64(96)), "float16"), p_reshape259: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic sequence extent, bound by the match_buffer shapes below.
        seq_len = T.int64()
        reshape259 = T.match_buffer(p_reshape259, (T.int64(1), seq_len, T.int64(3072)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(3072)), "float16")
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(3072), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(3072), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(3072), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mixer_out_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mixer_out_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(3072), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mixer_out_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mixer_out_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(3072), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape259[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape259[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul with fully static shapes (single row: the
    # activation is (1, 1, 3072)). Unpacks a 4-bit group-quantized weight of
    # logical shape (16384, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mlp_gate_up_proj_q_weight2: (16384, 384) uint32; each
    #       word packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mlp_gate_up_proj_q_scale2: (16384, 96) float16; one
    #       scale per group of 32 consecutive codes (3072 / 32 = 96).
    #   rms_norm66: (1, 1, 3072) float16 activation.
    #   NT_matmul_intermediate: (1, 1, 16384) float16 output,
    #       out[0, 0, n] = sum_k rms_norm66[0, 0, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul12(transformer_h_0_mlp_gate_up_proj_q_weight2: T.Buffer((T.int64(16384), T.int64(384)), "uint32"), transformer_h_0_mlp_gate_up_proj_q_scale2: T.Buffer((T.int64(16384), T.int64(96)), "float16"), rms_norm66: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(16384)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(16384), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(16384), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(16384), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mlp_gate_up_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mlp_gate_up_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(16384), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mlp_gate_up_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mlp_gate_up_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(16384), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm66[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm66[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul, dynamic-batch variant (activation shape
    # (batch_size, 1, 3072)). Unpacks a 4-bit group-quantized weight of logical
    # shape (16384, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mlp_gate_up_proj_q_weight4: (16384, 384) uint32; each
    #       word packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mlp_gate_up_proj_q_scale4: (16384, 96) float16; one
    #       scale per group of 32 consecutive codes (3072 / 32 = 96).
    #   p_rms_norm196: handle matched to a (batch_size, 1, 3072) float16 input.
    #   p_output0: handle matched to the (batch_size, 1, 16384) float16 output,
    #       out[b, i, n] = sum_k rms_norm196[b, i, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul2(transformer_h_0_mlp_gate_up_proj_q_weight4: T.Buffer((T.int64(16384), T.int64(384)), "uint32"), transformer_h_0_mlp_gate_up_proj_q_scale4: T.Buffer((T.int64(16384), T.int64(96)), "float16"), p_rms_norm196: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic batch extent, bound by the match_buffer shapes below.
        batch_size = T.int64()
        rms_norm196 = T.match_buffer(p_rms_norm196, (batch_size, T.int64(1), T.int64(3072)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(16384)), "float16")
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(16384), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(16384), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(16384), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mlp_gate_up_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mlp_gate_up_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(16384), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mlp_gate_up_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mlp_gate_up_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(16384), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm196[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm196[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul, dynamic-sequence variant (activation shape
    # (1, seq_len, 3072)). Unpacks a 4-bit group-quantized weight of logical
    # shape (16384, 3072) to float16 and contracts it over K = 3072.
    #
    # Parameters:
    #   transformer_h_0_mlp_gate_up_proj_q_weight3: (16384, 384) uint32; each
    #       word packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   transformer_h_0_mlp_gate_up_proj_q_scale3: (16384, 96) float16; one
    #       scale per group of 32 consecutive codes (3072 / 32 = 96).
    #   p_rms_norm131: handle matched to a (1, seq_len, 3072) float16 input.
    #   p_output0: handle matched to the (1, seq_len, 16384) float16 output,
    #       out[0, s, n] = sum_k rms_norm131[0, s, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul7(transformer_h_0_mlp_gate_up_proj_q_weight3: T.Buffer((T.int64(16384), T.int64(384)), "uint32"), transformer_h_0_mlp_gate_up_proj_q_scale3: T.Buffer((T.int64(16384), T.int64(96)), "float16"), p_rms_norm131: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic sequence extent, bound by the match_buffer shapes below.
        seq_len = T.int64()
        rms_norm131 = T.match_buffer(p_rms_norm131, (T.int64(1), seq_len, T.int64(3072)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(16384)), "float16")
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(16384), T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(16384), T.int64(3072)), "float16")
        for i0, i1 in T.grid(T.int64(16384), T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mlp_gate_up_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mlp_gate_up_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(16384), T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mlp_gate_up_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mlp_gate_up_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(16384), T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm131[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 3072 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm131[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul with fully static shapes (single row: the
    # activation is (1, 1, 8192)). Unpacks a 4-bit group-quantized weight of
    # logical shape (3072, 8192) to float16 and contracts it over K = 8192.
    #
    # Parameters:
    #   transformer_h_0_mlp_down_proj_q_weight2: (3072, 1024) uint32; each word
    #       packs eight 4-bit codes (1024 * 8 = 8192 codes per row).
    #   transformer_h_0_mlp_down_proj_q_scale2: (3072, 256) float16; one scale
    #       per group of 32 consecutive codes (8192 / 32 = 256).
    #   lv101: (1, 1, 8192) float16 activation.
    #   NT_matmul_intermediate: (1, 1, 3072) float16 output,
    #       out[0, 0, n] = sum_k lv101[0, 0, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul13(transformer_h_0_mlp_down_proj_q_weight2: T.Buffer((T.int64(3072), T.int64(1024)), "uint32"), transformer_h_0_mlp_down_proj_q_scale2: T.Buffer((T.int64(3072), T.int64(256)), "float16"), lv101: T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(3072), T.int64(8192)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(3072), T.int64(8192)), "float16")
        for i0, i1 in T.grid(T.int64(3072), T.int64(8192)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mlp_down_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mlp_down_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(3072), T.int64(8192)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mlp_down_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mlp_down_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(3072), T.int64(8192)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv101[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 8192 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv101[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul, dynamic-batch variant (activation shape
    # (batch_size, 1, 8192)). Unpacks a 4-bit group-quantized weight of logical
    # shape (3072, 8192) to float16 and contracts it over K = 8192.
    #
    # Parameters:
    #   transformer_h_0_mlp_down_proj_q_weight4: (3072, 1024) uint32; each word
    #       packs eight 4-bit codes (1024 * 8 = 8192 codes per row).
    #   transformer_h_0_mlp_down_proj_q_scale4: (3072, 256) float16; one scale
    #       per group of 32 consecutive codes (8192 / 32 = 256).
    #   p_lv: handle matched to a (batch_size, 1, 8192) float16 input.
    #   p_output0: handle matched to the (batch_size, 1, 3072) float16 output,
    #       out[b, i, n] = sum_k lv[b, i, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul3(transformer_h_0_mlp_down_proj_q_weight4: T.Buffer((T.int64(3072), T.int64(1024)), "uint32"), transformer_h_0_mlp_down_proj_q_scale4: T.Buffer((T.int64(3072), T.int64(256)), "float16"), p_lv: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic batch extent, bound by the match_buffer shapes below.
        batch_size = T.int64()
        lv = T.match_buffer(p_lv, (batch_size, T.int64(1), T.int64(8192)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(3072)), "float16")
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(3072), T.int64(8192)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(3072), T.int64(8192)), "float16")
        for i0, i1 in T.grid(T.int64(3072), T.int64(8192)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mlp_down_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mlp_down_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(3072), T.int64(8192)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mlp_down_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mlp_down_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(3072), T.int64(8192)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 8192 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul, dynamic-sequence variant (activation shape
    # (1, seq_len, 8192)). Unpacks a 4-bit group-quantized weight of logical
    # shape (3072, 8192) to float16 and contracts it over K = 8192.
    #
    # Parameters:
    #   transformer_h_0_mlp_down_proj_q_weight3: (3072, 1024) uint32; each word
    #       packs eight 4-bit codes (1024 * 8 = 8192 codes per row).
    #   transformer_h_0_mlp_down_proj_q_scale3: (3072, 256) float16; one scale
    #       per group of 32 consecutive codes (8192 / 32 = 256).
    #   p_lv33: handle matched to a (1, seq_len, 8192) float16 input.
    #   p_output0: handle matched to the (1, seq_len, 3072) float16 output,
    #       out[0, s, n] = sum_k lv33[0, s, k] * W[n, k].
    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul8(transformer_h_0_mlp_down_proj_q_weight3: T.Buffer((T.int64(3072), T.int64(1024)), "uint32"), transformer_h_0_mlp_down_proj_q_scale3: T.Buffer((T.int64(3072), T.int64(256)), "float16"), p_lv33: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic sequence extent, bound by the match_buffer shapes below.
        seq_len = T.int64()
        lv33 = T.match_buffer(p_lv33, (T.int64(1), seq_len, T.int64(8192)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(3072)), "float16")
        # with T.block("root"):
        # Full-size float16 scratch buffers: the unpacked integer codes, then the
        # dequantized weight. The loop nests below are plain serial T.grid loops
        # with no thread bindings.
        compute = T.alloc_buffer((T.int64(3072), T.int64(8192)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(3072), T.int64(8192)), "float16")
        for i0, i1 in T.grid(T.int64(3072), T.int64(8192)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_h_0_mlp_down_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_h_0_mlp_down_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(T.int64(3072), T.int64(8192)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], transformer_h_0_mlp_down_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * transformer_h_0_mlp_down_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(3072), T.int64(8192)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv33[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates directly in the float16 output buffer.
                # NOTE(review): there is no wider accumulator; confirm fp16
                # accumulation over K = 8192 is acceptable.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv33[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    # Fused dequantize + NT matmul + cast to float32, single-row variant (the
    # activation is (1, 1, 3072)). The weight's row count, vocab_size, is
    # symbolic.
    #
    # Parameters:
    #   p_lm_head_q_weight2: handle matched to (vocab_size, 384) uint32; each
    #       word packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   p_lm_head_q_scale2: handle matched to (vocab_size, 96) float16; one
    #       scale per group of 32 consecutive codes (3072 / 32 = 96).
    #   rms_norm129: (1, 1, 3072) float16 activation.
    #   p_output0: handle matched to a (1, 1, vocab_size) buffer declared
    #       without an explicit dtype (TVMScript's default, float32). It
    #       receives the float16 matmul result cast to float32.
    @T.prim_func(private=True)
    def fused_dequantize5_fused_NT_matmul14_cast2(p_lm_head_q_weight2: T.handle, p_lm_head_q_scale2: T.handle, rms_norm129: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic row count of the weight/scale and length of the output.
        vocab_size = T.int64()
        lm_head_q_weight2 = T.match_buffer(p_lm_head_q_weight2, (vocab_size, T.int64(384)), "uint32")
        lm_head_q_scale2 = T.match_buffer(p_lm_head_q_scale2, (vocab_size, T.int64(96)), "float16")
        compute_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), vocab_size))
        # with T.block("root"):
        # Scratch buffers: unpacked codes, the dequantized weight, and the float16
        # matmul result before the final cast. The loop nests below are plain
        # serial T.grid loops with no thread bindings.
        compute = T.alloc_buffer((vocab_size, T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((vocab_size, T.int64(3072)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), vocab_size), "float16")
        for i0, i1 in T.grid(vocab_size, T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lm_head_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lm_head_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(vocab_size, T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], lm_head_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * lm_head_q_scale2[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), vocab_size, T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm129[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates in float16. Only the final result is widened to
                # float32, in the "compute_1" block below.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm129[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), vocab_size):
            with T.block("compute_1"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute_intermediate_intermediate[v_i0, v_i1, v_i2])
                # Elementwise float16 -> float32 cast into the output buffer.
                compute_intermediate_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2])

    # Fused dequantize + NT matmul + cast to float32, dynamic-batch variant
    # (activation shape (batch_size, 1, 3072)). The weight's row count,
    # vocab_size, is symbolic.
    #
    # Parameters:
    #   p_lm_head_q_weight4: handle matched to (vocab_size, 384) uint32; each
    #       word packs eight 4-bit codes (384 * 8 = 3072 codes per row).
    #   p_lm_head_q_scale4: handle matched to (vocab_size, 96) float16; one
    #       scale per group of 32 consecutive codes (3072 / 32 = 96).
    #   p_rms_norm259: handle matched to a (batch_size, 1, 3072) float16 input.
    #   p_output0: handle matched to a (batch_size, 1, vocab_size) buffer
    #       declared without an explicit dtype (TVMScript's default, float32).
    #       It receives the float16 matmul result cast to float32.
    @T.prim_func(private=True)
    def fused_dequantize5_fused_NT_matmul4_cast(p_lm_head_q_weight4: T.handle, p_lm_head_q_scale4: T.handle, p_rms_norm259: T.handle, p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        # Symbolic row count of the weight/scale and last dim of the output.
        vocab_size = T.int64()
        lm_head_q_weight4 = T.match_buffer(p_lm_head_q_weight4, (vocab_size, T.int64(384)), "uint32")
        lm_head_q_scale4 = T.match_buffer(p_lm_head_q_scale4, (vocab_size, T.int64(96)), "float16")
        # Symbolic batch extent, bound by the match_buffer shapes below.
        batch_size = T.int64()
        rms_norm259 = T.match_buffer(p_rms_norm259, (batch_size, T.int64(1), T.int64(3072)), "float16")
        compute_intermediate_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), vocab_size))
        # with T.block("root"):
        # Scratch buffers: unpacked codes, the dequantized weight, and the float16
        # matmul result before the final cast. The loop nests below are plain
        # serial T.grid loops with no thread bindings.
        compute = T.alloc_buffer((vocab_size, T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((vocab_size, T.int64(3072)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), vocab_size), "float16")
        for i0, i1 in T.grid(vocab_size, T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lm_head_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                # The code for column v_i1 sits in word v_i1 // 8 at bit offset
                # 4 * (v_i1 % 8) (lowest nibble first). Masking with 15 keeps
                # 4 bits, giving an integer in [0, 15].
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lm_head_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        for i0, i1 in T.grid(vocab_size, T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], lm_head_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                # Recentre the code by 7 (range [-7, 8]), then multiply by the
                # scale shared by its group of 32 consecutive columns.
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * lm_head_q_scale4[v_i0, v_i1 // T.int64(32)]
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), vocab_size, T.int64(3072)):
            with T.block("NT_matmul"):
                # Three spatial axes plus one reduction axis (k). "NT": the
                # weight operand is read as [n, k], i.e. transposed.
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm259[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                # Accumulates in float16. Only the final result is widened to
                # float32, in the "compute_1" block below.
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm259[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        for i0, i1, i2 in T.grid(batch_size, T.int64(1), vocab_size):
            with T.block("compute_1"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute_intermediate_intermediate[v_i0, v_i1, v_i2])
                # Elementwise float16 -> float32 cast into the output buffer.
                compute_intermediate_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2])

    @T.prim_func(private=True)
    def fused_dequantize5_fused_NT_matmul9_cast1(p_lm_head_q_weight3: T.handle, p_lm_head_q_scale3: T.handle, p_take1: T.handle, p_output0: T.handle):
        # Fused dequantize + NT matmul + cast (presumably the LM-head projection,
        # going by the parameter names). The batch dimension here is axis 1:
        #   out[0, b, v] = float32(sum_k take1[0, b, k] * W[v, k])
        # where W is a (vocab_size, 3072) float16 matrix rebuilt from 4-bit codes:
        #   W[v, k] = (code(v, k) - 7) * lm_head_q_scale3[v, k // 32]
        #   code(v, k) = (lm_head_q_weight3[v, k // 8] >> (4 * (k % 8))) & 0xF
        T.func_attr({"tir.noalias": T.bool(True)})
        vocab_size = T.int64()
        lm_head_q_weight3 = T.match_buffer(p_lm_head_q_weight3, (vocab_size, T.int64(384)), "uint32")
        lm_head_q_scale3 = T.match_buffer(p_lm_head_q_scale3, (vocab_size, T.int64(96)), "float16")
        batch_size = T.int64()
        take1 = T.match_buffer(p_take1, (T.int64(1), batch_size, T.int64(3072)), "float16")
        compute_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), batch_size, vocab_size))
        # with T.block("root"):
        compute = T.alloc_buffer((vocab_size, T.int64(3072)), "float16")
        dequantize_intermediate = T.alloc_buffer((vocab_size, T.int64(3072)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), batch_size, vocab_size), "float16")
        # Stage 1: extract the 4-bit code for each (row, column); 8 codes per uint32.
        for i0, i1 in T.grid(vocab_size, T.int64(3072)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lm_head_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lm_head_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Stage 2: center the code at 7 and apply one float16 scale per 32 columns.
        for i0, i1 in T.grid(vocab_size, T.int64(3072)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], lm_head_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * lm_head_q_scale3[v_i0, v_i1 // T.int64(32)]
        # Stage 3: NT matmul reducing over k (weight used transposed).
        # NOTE(review): accumulation happens in a float16 buffer. Confirm this is intended.
        for i0, i1, i2, k in T.grid(T.int64(1), batch_size, vocab_size, T.int64(3072)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(take1[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + take1[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Stage 4: widen the float16 result to float32 for the output buffer.
        for i0, i1, i2 in T.grid(T.int64(1), batch_size, vocab_size):
            with T.block("compute_1"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute_intermediate_intermediate[v_i0, v_i1, v_i2])
                compute_intermediate_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2])

    @T.prim_func(private=True)
    def fused_dequantize_take1(transformer_embd_q_weight: T.Buffer((32064, 384), "uint32"), transformer_embd_q_scale: T.Buffer((32064, 96), "float16"), p_input_ids: T.handle, p_output0: T.handle):
        # Quantized embedding lookup over a fixed 32064 x 3072 table:
        #   out[t, c] = (code(input_ids[t], c) - 7) * transformer_embd_q_scale[input_ids[t], c // 32]
        #   code(r, c) = (transformer_embd_q_weight[r, c // 8] >> (4 * (c % 8))) & 0xF
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int32()
        input_ids = T.match_buffer(p_input_ids, (seq_len,), "int32")
        T_take_intermediate = T.match_buffer(p_output0, (seq_len, 3072), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((32064, 3072), "float16")
        # Stage 1: unpack the integer codes for all 32064 rows, independent of seq_len.
        # NOTE(review): presumably a later pass inlines this into T_take so only the
        # gathered rows are unpacked. Confirm.
        for i0, i1 in T.grid(32064, 3072):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(transformer_embd_q_weight[v_i0, v_i1 // 8])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(transformer_embd_q_weight[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15)))
        # Stage 2: gather one row per token id and dequantize it.
        # NOTE(review): input_ids are used as row indices with no range check. Ids
        # outside [0, 32064) would read out of bounds, so callers must guarantee valid ids.
        for ax0, ax1 in T.grid(seq_len, 3072):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(input_ids[v_ax0], compute[input_ids[v_ax0], v_ax1], transformer_embd_q_scale[input_ids[v_ax0], v_ax1 // 32])
                T.writes(T_take_intermediate[v_ax0, v_ax1])
                T_take_intermediate[v_ax0, v_ax1] = (compute[input_ids[v_ax0], v_ax1] - T.float16(7.0)) * transformer_embd_q_scale[input_ids[v_ax0], v_ax1 // 32]

    @T.prim_func(private=True)
    def fused_reshape10_reshape11(lv164: T.Buffer((T.int64(1), T.int64(32), T.int64(96)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16")):
        # Two chained reshapes: (1, 32, 96) -> (1, 1, 32, 96) -> (1, 1, 3072).
        # Net effect is a flat row-major copy:
        #   T_reshape_intermediate_1[0, 0, h * 96 + d] = lv164[0, h, d]
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(96)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # v_ax3 < 96, so (v_ax3 // 96 + v_ax2) % 32 simplifies to v_ax2.
                T.reads(lv164[T.int64(0), (v_ax3 // T.int64(96) + v_ax2) % T.int64(32), v_ax3 % T.int64(96)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv164[T.int64(0), (v_ax3 // T.int64(96) + v_ax2) % T.int64(32), v_ax3 % T.int64(96)]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(3072)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Split the flat index back into (h, d). The "% 3072" is a no-op
                # because v_ax2 < 3072.
                T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(3072) // T.int64(96), v_ax2 % T.int64(96)])
                T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(3072) // T.int64(96), v_ax2 % T.int64(96)]

    @T.prim_func(private=True)
    def fused_reshape8_reshape9(lv387: T.Buffer((T.int64(1), T.int64(1), T.int64(9216)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(96), T.int64(96)), "float16")):
        # Two chained reshapes: (1, 1, 9216) -> (1, 1, 96, 96) -> (1, 96, 96).
        # Net effect: T_reshape_intermediate_1[0, h, d] = lv387[0, 0, h * 96 + d].
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(96), T.int64(96)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(96), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # v_ax2 * 96 + v_ax3 <= 9215, so the "% 9216" is a no-op.
                T.reads(lv387[T.int64(0), T.int64(0), (v_ax2 * T.int64(96) + v_ax3) % T.int64(9216)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv387[T.int64(0), T.int64(0), (v_ax2 * T.int64(96) + v_ax3) % T.int64(9216)]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(96), T.int64(96)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # v_ax2 < 96, so (v_ax2 // 96 + v_ax1) % 96 simplifies to v_ax1.
                T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(96) + v_ax1) % T.int64(96), v_ax2 % T.int64(96)])
                T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(96) + v_ax1) % T.int64(96), v_ax2 % T.int64(96)]

    @T.prim_func
    def fused_rope_longrope_scaling(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, ext_factors: T.Buffer((48,), "float32")):
        # Splits a packed QKV tensor (seq_len, 96 heads, 96 dims) into q, k and v, and
        # applies a rotary position embedding to q and k:
        #   qkv heads [0, 32)  -> q (rotated)
        #   qkv heads [32, 64) -> k (rotated)
        #   qkv heads [64, 96) -> v (copied unchanged)
        # For a rotated head with x = qkv[s, h, :] and d in [0, 96):
        #   freq  = position_map[s] / (ext_factors[d % 48] * 10000 ** ((2 * d % 96) / 96))
        #   rot_x = -x[d + 48] if d < 48 else x[d - 48]   (rotate-half pairing)
        #   out   = float16(c * cos(freq) * x[d] + c * sin(freq) * rot_x), c = 1.1902380714238083
        # d and d +/- 48 share one frequency and one ext_factors entry.
        # NOTE(review): c numerically equals sqrt(1 + ln(32) / ln(4096)) = sqrt(17/12).
        # It is presumably the LongRoPE attention-scaling factor baked in from the model
        # config. Confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        qkv = T.match_buffer(var_qkv, (seq_len, 96, 96), "float16")
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 32, 96), "float16")
        k = T.match_buffer(var_k, (seq_len, 32, 96), "float16")
        v = T.match_buffer(var_v, (seq_len, 32, 96), "float16")
        # with T.block("root"):
        for iters_0, iters_1, iters_2 in T.grid(seq_len, 96, 96):
            with T.block("llama_fused_rope"):
                s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2])
                # The read/write regions are conservative over-approximations. Each
                # iteration writes exactly one of q, k or v, selected by h below.
                T.reads(position_map[s], ext_factors[d % 48], qkv[s, h, d - 48:d - 48 + 97])
                T.writes(q[s, h, d], k[s, h - 32, d], v[s, h - 64, d])
                if h < 32:
                    # freq is declared here and bound by the T.Let "where" clause.
                    # d < 96 always holds (d ranges over 96), so rotation covers the
                    # whole head dim and the qkv[s, h, d] fallback is never taken.
                    freq = T.float32()
                    q[s, h, d] = T.if_then_else(d < 96, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(d < 48, qkv[s, h, d + 48] * T.float16(-1.0), qkv[s, h, d - 48]))), where={freq: T.Cast("float32", position_map[s]) / (ext_factors[d % 48] * T.pow(T.float32(10000.0), T.Cast("float32", d * 2 % 96) / T.float32(96.0)))}), qkv[s, h, d])
                else:
                    if h < 64:
                        # Same rotation as for q, written to k at head h - 32.
                        freq = T.float32()
                        k[s, h - 32, d] = T.if_then_else(d < 96, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(d < 48, qkv[s, h, d + 48] * T.float16(-1.0), qkv[s, h, d - 48]))), where={freq: T.Cast("float32", position_map[s]) / (ext_factors[d % 48] * T.pow(T.float32(10000.0), T.Cast("float32", d * 2 % 96) / T.float32(96.0)))}), qkv[s, h, d])
                    else:
                        # v heads are copied through without rotation.
                        v[s, h - 64, d] = qkv[s, h, d]

    @T.prim_func(private=True)
    def fused_split1_silu1_multiply1(p_lv131: T.handle, p_output0: T.handle):
        # SiLU-gated product over the two halves of the last axis:
        #   x0 = lv131[..., :8192], x1 = lv131[..., 8192:]
        #   out = x1 * (x0 * sigmoid(x0))   i.e. SiLU(x0) * x1
        # Shape (1, seq_len, 16384) -> (1, seq_len, 8192); every stage is float16.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv131 = T.match_buffer(p_lv131, (T.int64(1), seq_len, T.int64(16384)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(8192)), "float16")
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192)), "float16")
        compute = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192)), "float16")
        T_multiply_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(8192)), "float16")
        # First half of the last axis -> x0.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv131[v_ax0, v_ax1, v_ax2])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] = lv131[v_ax0, v_ax1, v_ax2]
        # Second half (offset 8192) -> x1.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv131[v_ax0, v_ax1, v_ax2 + T.int64(8192)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] = lv131[v_ax0, v_ax1, v_ax2 + T.int64(8192)]
        # sigmoid(x0)
        for i0, i1, i2 in T.grid(T.int64(1), seq_len, T.int64(8192)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_split_sections_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1, v_i2])
        # SiLU(x0) = x0 * sigmoid(x0)
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
        # x1 * SiLU(x0) -> output
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(8192)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2], T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] * T_multiply_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func(private=True)
    def fused_split2_silu2_multiply2(lv389: T.Buffer((T.int64(1), T.int64(1), T.int64(16384)), "float16"), T_multiply_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16")):
        # Static-shape (single position) SiLU-gated product:
        #   x0 = lv389[..., :8192], x1 = lv389[..., 8192:]
        #   out = x1 * (x0 * sigmoid(x0))   i.e. SiLU(x0) * x1
        # Shape (1, 1, 16384) -> (1, 1, 8192); every stage is float16.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16")
        compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16")
        T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(8192)), "float16")
        # First half of the last axis -> x0.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv389[v_ax0, v_ax1, v_ax2])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] = lv389[v_ax0, v_ax1, v_ax2]
        # Second half (offset 8192) -> x1.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv389[v_ax0, v_ax1, v_ax2 + T.int64(8192)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] = lv389[v_ax0, v_ax1, v_ax2 + T.int64(8192)]
        # sigmoid(x0)
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_split_sections_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1, v_i2])
        # SiLU(x0) = x0 * sigmoid(x0)
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
        # x1 * SiLU(x0) -> output
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(8192)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2], T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] * T_multiply_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func(private=True)
    def fused_split_silu_multiply(p_lv2: T.handle, p_output0: T.handle):
        # Batched SiLU-gated product over the two halves of the last axis:
        #   x0 = lv2[..., :8192], x1 = lv2[..., 8192:]
        #   out = x1 * (x0 * sigmoid(x0))   i.e. SiLU(x0) * x1
        # Shape (batch_size, 1, 16384) -> (batch_size, 1, 8192); every stage is float16.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv2 = T.match_buffer(p_lv2, (batch_size, T.int64(1), T.int64(16384)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(8192)), "float16")
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192)), "float16")
        compute = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192)), "float16")
        T_multiply_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(8192)), "float16")
        # First half of the last axis -> x0.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2]
        # Second half (offset 8192) -> x1.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2 + T.int64(8192)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2 + T.int64(8192)]
        # sigmoid(x0)
        for i0, i1, i2 in T.grid(batch_size, T.int64(1), T.int64(8192)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_split_sections_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1, v_i2])
        # SiLU(x0) = x0 * sigmoid(x0)
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
        # x1 * SiLU(x0) -> output
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(8192)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2], T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] * T_multiply_intermediate[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row gather: dst[b, j] = src[indices[b], j] for b in [0, batch_size), j in [0, n).
        # NOTE(review): indices are not range-checked against m. Callers must ensure
        # 0 <= indices[b] < m.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (m, n))
        batch_size = T.int32(is_size_var=True)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (batch_size, n))
        # with T.block("root"):
        for b, j in T.grid(batch_size, n):
            with T.block("gather_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                T.reads(src[indices[vb], vj], indices[vb])
                T.writes(dst[vb, vj])
                dst[vb, vj] = src[indices[vb], vj]

    @T.prim_func(private=True)
    def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
        # Inverse-CDF lookup over sorted cumulative sums, one result per output row.
        # For output row i, let r = sample_indices[i, 0], u = usample[i, 0], and
        #   t(j) = cumsum_sorted[r, j] / renorm_prob[r, 0].
        # Iteration (i, j) writes output_index[i, 0] = indices[r, j] when
        #   (u < t(j) or j == vocab_size - 1) and (j == 0 or u >= t(j - 1)).
        # That is the first column whose normalized cumulative sum exceeds u; if no
        # column does, the last column is used.
        # Assuming each cumsum_sorted row is nondecreasing, exactly one j per row
        # satisfies this, so each output element has a single writer. TODO confirm
        # callers guarantee monotone rows.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        indices = T.match_buffer(B, (batch, vocab_size), "int32")
        renorm_prob = T.match_buffer(C, (batch, 1))
        out_batch = T.int64(is_size_var=True)
        usample = T.match_buffer(D, (out_batch, 1))
        sample_indices = T.match_buffer(E, (out_batch, 1), "int32")
        output_index = T.match_buffer(F, (out_batch, 1), "int32")
        # with T.block("root"):
        for ax0, ax1 in T.grid(out_batch, vocab_size):
            with T.block("T_get_index_from_sorted"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
                T.writes(output_index[v_ax0, 0])
                # Upper-bound test: u falls below t(j), or j is the last column (fallback).
                if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
                    if v_ax1 == T.int64(0):
                        # No previous column to compare against: column 0 is selected.
                        output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                    else:
                        # Lower-bound test: u is at or above t(j - 1).
                        if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                            output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

    @T.prim_func(private=True)
    def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        # For each row b, define
        #   keep(j) = cumsum_sorted[b, j] < top_p[b, 0] and j + 1 < top_k[b, 0]
        # and write renorm_prob[b, 0] = cumsum_sorted[b, J], where J is the first
        # column with keep(J) false, or the last column if keep holds everywhere.
        # Assuming cumsum_sorted rows are nondecreasing, keep(j) is true on a prefix
        # of columns, so the else-branch below has at most one writer per row.
        # TODO confirm callers guarantee monotone rows.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        top_p = T.match_buffer(B, (batch, 1))
        top_k = T.match_buffer(C, (batch, 1), "int32")
        renorm_prob = T.match_buffer(D, (batch, 1))
        # with T.block("root"):
        for ax0, ax1 in T.grid(batch, vocab_size):
            with T.block("T_get_renorm_prob"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0])
                T.writes(renorm_prob[v_ax0, 0])
                # keep(0) is false, so J = 0. NOTE(review): every j iteration of this row
                # writes the same value here; this is redundant but consistent.
                if not (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > 1):
                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
                else:
                    # Only columns inside the kept prefix can contribute.
                    if cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0]):
                        if v_ax1 + T.int64(1) == vocab_size:
                            # All columns are kept: use the last cumulative sum.
                            renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
                        else:
                            # j is the last kept column when keep(j + 1) fails; J = j + 1.
                            if not (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0])):
                                renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]

    @T.prim_func(private=True)
    def index(var_rms_norm64: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16")):
        # Selects the last position along the sequence axis:
        #   index[0, 0, k] = rms_norm64[0, seq_len - 1, k]
        # NOTE(review): assumes seq_len >= 1; with seq_len == 0 the read index would be -1.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm64 = T.match_buffer(var_rms_norm64, (T.int64(1), seq_len, T.int64(3072)), "float16")
        # with T.block("root"):
        for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(3072)):
            with T.block("index"):
                v_i, v__, v_k = T.axis.remap("SSS", [i, _, k])
                T.reads(rms_norm64[v_i, seq_len - T.int64(1), v_k])
                T.writes(index[v_i, v__, v_k])
                index[v_i, v__, v_k] = rms_norm64[v_i, seq_len - T.int64(1), v_k]

    @T.prim_func
    def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle):
        # In-place merge of two partial results (V, S) and (V_other, S_other), where S
        # values are combined as base-2 logs (see exp2/log2 below). For each (n, h):
        #   m    = max(S, S_other)
        #   a, b = 2 ** (S - m), 2 ** (S_other - m)
        #   V   <- V * a / (a + b) + V_other * b / (a + b)    (per element, float32 math)
        #   S   <- log2(a + b) + m   (= log2(2 ** S + 2 ** S_other), computed stably)
        # Launch layout: blockIdx.x covers N; threadIdx.y covers 32 heads; threadIdx.x
        # covers 24 lanes x 4 contiguous elements = 96 elements of D. blockIdx.y has
        # extent 1, so "by * 32" is always 0.
        # NOTE(review): thread extents are fixed (32 heads, 24 * 4 = 96 elements) while H
        # and D are symbolic in match_buffer. Presumably this kernel is only valid for
        # H == 32 and D == 96. Confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True)
        V = T.match_buffer(v, (N, H, D), "float16")
        S = T.match_buffer(s, (N, H))
        V_other = T.match_buffer(v_other, (N, H, D), "float16")
        S_other = T.match_buffer(s_other, (N, H))
        # with T.block("root"):
        for bx in T.thread_binding(N, thread="blockIdx.x"):
            for by in T.thread_binding(1, thread="blockIdx.y"):
                for ty in T.thread_binding(32, thread="threadIdx.y"):
                    for tx in T.thread_binding(24, thread="threadIdx.x"):
                        with T.block("merge"):
                            T.reads(S[bx, ty + by * 32], S_other[bx, ty + by * 32], V[bx, ty + by * 32, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 32, tx * 4:tx * 4 + 4])
                            T.writes(V[bx, ty + by * 32, tx * 4:tx * 4 + 4], S[bx, ty + by * 32])
                            # scope="local" scratch for this thread's scalars and 4-wide slices.
                            s_val = T.alloc_buffer((1,), scope="local")
                            s_other_val = T.alloc_buffer((1,), scope="local")
                            s_max = T.alloc_buffer((1,), scope="local")
                            scale = T.alloc_buffer((1,), scope="local")
                            other_scale = T.alloc_buffer((1,), scope="local")
                            v_vec = T.alloc_buffer((4,), "float16", scope="local")
                            v_other_vec = T.alloc_buffer((4,), "float16", scope="local")
                            s_val[0] = S[bx, ty + by * 32]
                            s_other_val[0] = S_other[bx, ty + by * 32]
                            # Subtract the max before exp2 so neither weight overflows.
                            s_max[0] = T.max(s_val[0], s_other_val[0])
                            s_val[0] = T.exp2(s_val[0] - s_max[0])
                            s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
                            scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
                            other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
                            # Vectorized 4-element loads of both value slices.
                            for vec in T.vectorized(4):
                                v_vec[vec] = V[bx, ty + by * 32, tx * 4 + vec]
                            for vec in T.vectorized(4):
                                v_other_vec[vec] = V_other[bx, ty + by * 32, tx * 4 + vec]
                            # Weighted blend computed in float32, stored back as float16.
                            for vec in range(4):
                                v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0])
                            for vec in T.vectorized(4):
                                V[bx, ty + by * 32, tx * 4 + vec] = v_vec[vec]
                            # Combined log2 weight of both partial results.
                            S[bx, ty + by * 32] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]

    @T.prim_func
    def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle):
        # Inverse-CDF sampling. For each output b, scan row prob[row_indices[b, 0], :]
        # and select the first token whose cumulative probability reaches
        # uniform_samples[b, 0] - 1e-6; its id is written to token_ids[b, 0].
        #
        # Launch shape: one thread block per output (blockIdx.x over batch_size),
        # 4 x 32 threads (threadIdx.y x threadIdx.x). Each iteration of the while
        # loop consumes one 512-element chunk of the row, 4 elements per thread.
        T.func_attr({"tir.is_scheduled": 1})
        n, vocab_size = T.int64(), T.int64()
        prob = T.match_buffer(var_prob, (n, vocab_size))
        batch_size = T.int64()
        uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1))
        row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int32")
        token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int32")
        # with T.block("root"):
        # Per-thread scalars:
        #   aggregate       - probability mass of all chunks consumed so far
        #   sample_id_local - the selected token id
        #   step_iter       - index of the current 512-element chunk
        aggregate = T.alloc_buffer((), scope="local")
        sample_id_local = T.alloc_buffer((), "int32", scope="local")
        step_iter = T.alloc_buffer((), "int32", scope="local")
        for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            row_idx: T.int32 = row_indices[bx, 0]
            for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    u: T.float32 = uniform_samples[bx, 0]
                    aggregate[()] = T.Cast("float32", 0)
                    step_iter[()] = 0
                    # Fallback id, matching the default of the min-reduction below.
                    # sample_id_local is otherwise written only when some chunk
                    # crosses u - 1e-6. If the row's total mass stays below that
                    # (e.g. rounding in the probabilities with u close to 1), thread
                    # (0, 0) would store an uninitialized value to token_ids.
                    sample_id_local[()] = T.Cast("int32", vocab_size - T.int64(1))
                    # Always process at least one chunk. Stop once the consumed mass
                    # reaches u - eps, or after all ceil(vocab_size / 512) chunks.
                    while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)):
                        with T.block(""):
                            T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()])
                            T.writes(sample_id_local[()], aggregate[()])
                            prob_gt_threshold = T.alloc_buffer((T.int64(4),), scope="local")
                            cumsum = T.alloc_buffer((T.int64(512),), scope="shared")
                            greater_than_u = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            mask = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            valid = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            indices = T.alloc_buffer((T.int64(4),), "int32", scope="local")
                            step_aggregate = T.alloc_buffer((), scope="local")
                            # Load this thread's 4 probabilities. Out-of-range and
                            # non-positive entries contribute 0 and are marked invalid.
                            for v in T.unroll(T.int64(4)):
                                idx: T.int64 = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v
                                prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0))
                                prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0.0), prob_local, T.Cast("float32", 0))
                                valid[v] = prob_local > T.float32(0.0) and idx < vocab_size
                            # Sum the whole chunk: a per-thread partial sum, then a
                            # shared-memory tree reduction over the 128 threads.
                            # NOTE(review): there are no explicit barriers between the
                            # shared-memory steps here. They are presumably inserted by
                            # TVM's thread-sync lowering; confirm.
                            with T.block(""):
                                T.reads(prob_gt_threshold[T.int64(0):T.int64(4)])
                                T.writes(step_aggregate[()])
                                local_sum = T.alloc_buffer((), scope="local")
                                shared_buf = T.alloc_buffer((T.int64(128),), scope="shared")
                                idx: T.int64 = ty * T.int64(32) + tx
                                local_sum[()] = T.Cast("float32", 0)
                                for i in T.unroll(T.int64(4)):
                                    local_sum[()] = local_sum[()] + prob_gt_threshold[i]
                                shared_buf[idx] = local_sum[()]
                                for i in T.unroll(T.int64(7)):
                                    if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                        shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(T.int64(1), i)]
                                step_aggregate[()] = shared_buf[0]
                            # Only when u is crossed inside this chunk, locate the
                            # exact token.
                            if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)):
                                # In-thread inclusive prefix sum of the 4 values.
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - T.int64(1)]
                                for i in T.vectorized(T.int64(4)):
                                    cumsum[ty * T.int64(128) + tx * T.int64(4) + i] = prob_gt_threshold[i]
                                # Scan across the 32 threads of each warp segment. At step i,
                                # thread tx adds the last element of thread tx - 2^i
                                # (Hillis-Steele style, performed in place on cumsum).
                                for i in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i):
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                                # Warp 0 carries each 128-element segment's running total
                                # into the next segment, sequentially for segments 1..3.
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)]
                                for v in T.unroll(T.int64(4)):
                                    greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07)
                                # mask[v] is true where the flag differs from the preceding
                                # element's flag. The preceding thread's last flag is shared
                                # via shared_buf. Scanned values are non-negative, so the
                                # flags are monotone and flip at most once.
                                with T.block(""):
                                    T.reads(greater_than_u[T.int64(0):T.int64(4)])
                                    T.writes(mask[T.int64(0):T.int64(4)])
                                    shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared")
                                    tx_idx: T.int64 = ty * T.int64(32) + tx
                                    shared_buf[tx_idx] = greater_than_u[T.int64(3)]
                                    mask[0] = T.if_then_else(tx_idx != T.int64(0), T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - T.int64(1)]), greater_than_u[0])
                                    for i in T.unroll(T.int64(1), T.int64(4)):
                                        mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - T.int64(1)])
                                for v in T.unroll(T.int64(4)):
                                    mask[v] = mask[v] and valid[v]
                                    indices[v] = T.Cast("int32", T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v)
                                # Min-reduce the masked token ids across the block. The
                                # default is vocab_size - 1 when no position is masked.
                                with T.block(""):
                                    T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)])
                                    T.writes(sample_id_local[()])
                                    local_sum = T.alloc_buffer((), "int32", scope="local")
                                    shared_buf = T.alloc_buffer((T.int64(128),), "int32", scope="shared")
                                    idx: T.int64 = ty * T.int64(32) + tx
                                    local_sum[()] = T.Cast("int32", vocab_size - T.int64(1))
                                    for i in T.unroll(T.int64(4)):
                                        if mask[i]:
                                            local_sum[()] = T.min(local_sum[()], indices[i])
                                    shared_buf[idx] = local_sum[()]
                                    for i in T.unroll(T.int64(7)):
                                        if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                            shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(T.int64(1), i)])
                                    sample_id_local[()] = shared_buf[0]
                            aggregate[()] = aggregate[()] + step_aggregate[()]
                        step_iter[()] = step_iter[()] + 1
                    # A single thread per block publishes the result.
                    if tx == T.int64(0) and ty == T.int64(0):
                        token_ids[bx, 0] = sample_id_local[()]

    @T.prim_func(private=True)
    def reshape(var_lv: T.handle, var_T_reshape: T.handle):
        # Reshape a float16 (batch_size, 1, 9216) tensor into (batch_size, 1, 96, 96).
        # Each output element copies exactly one input element.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv = T.match_buffer(var_lv, (batch_size, T.int64(1), T.int64(9216)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(96), T.int64(96)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(batch_size, T.int64(1), T.int64(96), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # Generic row-major re-indexing. The flat offset v_ax2 * 96 + v_ax3 is
                # always < 9216, so the `// 9216` term is 0; v_ax1 is 0 and the
                # `% batch_size` is a no-op. The source element is therefore
                # lv[v_ax0, 0, v_ax2 * 96 + v_ax3].
                T.reads(lv[((v_ax2 * T.int64(96) + v_ax3) // T.int64(9216) + v_ax0 + v_ax1) % batch_size, T.int64(0), (v_ax2 * T.int64(96) + v_ax3) % T.int64(9216)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv[((v_ax2 * T.int64(96) + v_ax3) // T.int64(9216) + v_ax0 + v_ax1) % batch_size, T.int64(0), (v_ax2 * T.int64(96) + v_ax3) % T.int64(9216)]

    @T.prim_func(private=True)
    def reshape1(var_reshape384: T.handle, var_T_reshape: T.handle):
        # Drop the singleton axis 1: float16 (batch_size, 1, 96, 96) -> (batch_size, 96, 96).
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape384 = T.match_buffer(var_reshape384, (batch_size, T.int64(1), T.int64(96), T.int64(96)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(96), T.int64(96)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(96), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # With v_ax1, v_ax2 < 96 the `// 96` terms vanish and the `%` terms are
                # no-ops, so this reads reshape384[v_ax0, 0, v_ax1, v_ax2].
                T.reads(reshape384[((v_ax2 // T.int64(96) + v_ax1) // T.int64(96) + v_ax0) % batch_size, T.int64(0), (v_ax2 // T.int64(96) + v_ax1) % T.int64(96), v_ax2 % T.int64(96)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape384[((v_ax2 // T.int64(96) + v_ax1) // T.int64(96) + v_ax0) % batch_size, T.int64(0), (v_ax2 // T.int64(96) + v_ax1) % T.int64(96), v_ax2 % T.int64(96)]

    @T.prim_func(private=True)
    def reshape2(var_lv486: T.handle, var_T_reshape: T.handle):
        # Insert a singleton axis 1: float16 (batch_size, 32, 96) -> (batch_size, 1, 32, 96).
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv486 = T.match_buffer(var_lv486, (batch_size, T.int64(32), T.int64(96)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(32), T.int64(96)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(batch_size, T.int64(1), T.int64(32), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # For in-range indices (v_ax1 == 0, v_ax2 < 32, v_ax3 < 96) this reduces
                # to lv486[v_ax0, v_ax2, v_ax3].
                T.reads(lv486[((v_ax3 // T.int64(96) + v_ax2) // T.int64(32) + v_ax0 + v_ax1) % batch_size, (v_ax3 // T.int64(96) + v_ax2) % T.int64(32), v_ax3 % T.int64(96)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv486[((v_ax3 // T.int64(96) + v_ax2) // T.int64(32) + v_ax0 + v_ax1) % batch_size, (v_ax3 // T.int64(96) + v_ax2) % T.int64(32), v_ax3 % T.int64(96)]

    @T.prim_func(private=True)
    def reshape3(var_reshape386: T.handle, var_T_reshape: T.handle):
        # Flatten the trailing (32, 96) axes: float16 (batch_size, 1, 32, 96) -> (batch_size, 1, 3072).
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape386 = T.match_buffer(var_reshape386, (batch_size, T.int64(1), T.int64(32), T.int64(96)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(3072)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(3072)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # For in-range indices this reads
                # reshape386[v_ax0, 0, v_ax2 // 96, v_ax2 % 96].
                T.reads(reshape386[(v_ax2 // T.int64(3072) + v_ax0 + v_ax1) % batch_size, T.int64(0), v_ax2 % T.int64(3072) // T.int64(96), v_ax2 % T.int64(96)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape386[(v_ax2 // T.int64(3072) + v_ax0 + v_ax1) % batch_size, T.int64(0), v_ax2 % T.int64(3072) // T.int64(96), v_ax2 % T.int64(96)]

    @T.prim_func(private=True)
    def reshape4(var_lv129: T.handle, var_T_reshape: T.handle):
        # Split the last axis: float16 (1, seq_len, 9216) -> (1, seq_len, 96, 96).
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv129 = T.match_buffer(var_lv129, (T.int64(1), seq_len, T.int64(9216)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(96), T.int64(96)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), seq_len, T.int64(96), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # v_ax0 is always 0 and v_ax2 * 96 + v_ax3 < 9216, so this reads
                # lv129[0, v_ax1, v_ax2 * 96 + v_ax3].
                T.reads(lv129[T.int64(0), ((v_ax2 * T.int64(96) + v_ax3) // T.int64(9216) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax2 * T.int64(96) + v_ax3) % T.int64(9216)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv129[T.int64(0), ((v_ax2 * T.int64(96) + v_ax3) // T.int64(9216) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax2 * T.int64(96) + v_ax3) % T.int64(9216)]

    @T.prim_func(private=True)
    def reshape5(var_reshape256: T.handle, var_T_reshape: T.handle):
        # Drop the leading singleton axis: float16 (1, seq_len, 96, 96) -> (seq_len, 96, 96).
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape256 = T.match_buffer(var_reshape256, (T.int64(1), seq_len, T.int64(96), T.int64(96)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(96), T.int64(96)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(seq_len, T.int64(96), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # For in-range indices this reads reshape256[0, v_ax0, v_ax1, v_ax2].
                T.reads(reshape256[T.int64(0), ((v_ax2 // T.int64(96) + v_ax1) // T.int64(96) + v_ax0) % seq_len, (v_ax2 // T.int64(96) + v_ax1) % T.int64(96), v_ax2 % T.int64(96)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape256[T.int64(0), ((v_ax2 // T.int64(96) + v_ax1) // T.int64(96) + v_ax0) % seq_len, (v_ax2 // T.int64(96) + v_ax1) % T.int64(96), v_ax2 % T.int64(96)]

    @T.prim_func(private=True)
    def reshape6(var_lv325: T.handle, var_T_reshape: T.handle):
        # Insert a leading singleton axis: float16 (seq_len, 32, 96) -> (1, seq_len, 32, 96).
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv325 = T.match_buffer(var_lv325, (seq_len, T.int64(32), T.int64(96)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(32), T.int64(96)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), seq_len, T.int64(32), T.int64(96)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # v_ax0 is always 0; for in-range indices this reads lv325[v_ax1, v_ax2, v_ax3].
                T.reads(lv325[((v_ax3 // T.int64(96) + v_ax2) // T.int64(32) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(96) + v_ax2) % T.int64(32), v_ax3 % T.int64(96)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv325[((v_ax3 // T.int64(96) + v_ax2) // T.int64(32) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(96) + v_ax2) % T.int64(32), v_ax3 % T.int64(96)]

    @T.prim_func(private=True)
    def reshape7(var_reshape258: T.handle, var_T_reshape: T.handle):
        # Flatten the trailing (32, 96) axes: float16 (1, seq_len, 32, 96) -> (1, seq_len, 3072).
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape258 = T.match_buffer(var_reshape258, (T.int64(1), seq_len, T.int64(32), T.int64(96)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(3072)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(3072)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # For in-range indices this reads
                # reshape258[0, v_ax1, v_ax2 // 96, v_ax2 % 96].
                T.reads(reshape258[T.int64(0), (v_ax2 // T.int64(3072) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(3072) // T.int64(96), v_ax2 % T.int64(96)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape258[T.int64(0), (v_ax2 // T.int64(3072) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(3072) // T.int64(96), v_ax2 % T.int64(96)]

    @T.prim_func(private=True)
    def rms_norm(var_input_embeds: T.handle, transformer_h_0_ln_weight4: T.Buffer((T.int64(3072),), "float16"), var_T_cast: T.handle):
        # RMS normalization over the last axis (size 3072) of a float16
        # (batch_size, 1, 3072) tensor. The computation is done in float32 and
        # the result is cast back to float16:
        #   out = x * rsqrt(sum(x^2) / 3072 + 1e-5) * weight
        # Unscheduled form: every stage materializes a full float32 intermediate
        # buffer, and there are no thread bindings.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (batch_size, T.int64(1), T.int64(3072)), "float16")
        T_cast = T.match_buffer(var_T_cast, (batch_size, T.int64(1), T.int64(3072)), "float16")
        # with T.block("root"):
        T_cast_1 = T.alloc_buffer((batch_size, T.int64(1), T.int64(3072)))
        T_multiply = T.alloc_buffer((batch_size, T.int64(1), T.int64(3072)))
        T_multiply_red = T.alloc_buffer((batch_size, T.int64(1)))
        rsqrt = T.alloc_buffer((batch_size, T.int64(1)))
        T_cast_2 = T.alloc_buffer((T.int64(3072),))
        T_rms_norm = T.alloc_buffer((batch_size, T.int64(1), T.int64(3072)))
        # x (float16) -> float32
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(3072)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embeds[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2])
        # x^2
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(3072)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # sum(x^2) over the last axis
        for ax0, ax1, k2 in T.grid(batch_size, T.int64(1), T.int64(3072)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # 0.00032552083... == 1 / 3072, so the argument is mean(x^2) + 1e-5.
        for ax0, ax1 in T.grid(batch_size, T.int64(1)):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00032552083333333332) + T.float32(1.0000000000000001e-05))
        # weight (float16) -> float32
        for ax0 in range(T.int64(3072)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(3072), ax0)
                T.reads(transformer_h_0_ln_weight4[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", transformer_h_0_ln_weight4[v_ax0])
        # x * rsqrt(...) * weight
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(3072)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # result -> float16 output
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(3072)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func(private=True)
    def rms_norm1(var_input_embeds: T.handle, transformer_h_0_ln_weight3: T.Buffer((T.int64(3072),), "float16"), var_T_cast: T.handle):
        # RMS normalization over the last axis (size 3072) of a float16
        # (1, seq_len, 3072) tensor. The computation is done in float32 and the
        # result is cast back to float16:
        #   out = x * rsqrt(sum(x^2) / 3072 + 1e-5) * weight
        # Unscheduled form: every stage materializes a full float32 intermediate
        # buffer, and there are no thread bindings.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (T.int64(1), seq_len, T.int64(3072)), "float16")
        T_cast = T.match_buffer(var_T_cast, (T.int64(1), seq_len, T.int64(3072)), "float16")
        # with T.block("root"):
        T_cast_1 = T.alloc_buffer((T.int64(1), seq_len, T.int64(3072)))
        T_multiply = T.alloc_buffer((T.int64(1), seq_len, T.int64(3072)))
        T_multiply_red = T.alloc_buffer((T.int64(1), seq_len))
        rsqrt = T.alloc_buffer((T.int64(1), seq_len))
        T_cast_2 = T.alloc_buffer((T.int64(3072),))
        T_rms_norm = T.alloc_buffer((T.int64(1), seq_len, T.int64(3072)))
        # x (float16) -> float32
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(3072)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embeds[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2])
        # x^2
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(3072)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # sum(x^2) over the last axis
        for ax0, ax1, k2 in T.grid(T.int64(1), seq_len, T.int64(3072)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # 0.00032552083... == 1 / 3072, so the argument is mean(x^2) + 1e-5.
        for ax0, ax1 in T.grid(T.int64(1), seq_len):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00032552083333333332) + T.float32(1.0000000000000001e-05))
        # weight (float16) -> float32
        for ax0 in range(T.int64(3072)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(3072), ax0)
                T.reads(transformer_h_0_ln_weight3[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", transformer_h_0_ln_weight3[v_ax0])
        # x * rsqrt(...) * weight
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(3072)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # result -> float16 output
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(3072)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func(private=True)
    def rms_norm2(input_embed: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16"), transformer_h_0_ln_weight2: T.Buffer((T.int64(3072),), "float16"), T_cast: T.Buffer((T.int64(1), T.int64(1), T.int64(3072)), "float16")):
        # Static-shape RMS normalization of a single float16 (1, 1, 3072) vector.
        # The computation is done in float32 and the result is cast back to float16:
        #   out = x * rsqrt(sum(x^2) / 3072 + 1e-5) * weight
        # Unscheduled form: every stage materializes a float32 intermediate buffer.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_cast_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(3072)))
        T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(3072)))
        T_multiply_red = T.alloc_buffer((T.int64(1), T.int64(1)))
        rsqrt = T.alloc_buffer((T.int64(1), T.int64(1)))
        T_cast_2 = T.alloc_buffer((T.int64(3072),))
        T_rms_norm = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(3072)))
        # x (float16) -> float32
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(3072)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embed[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embed[v_ax0, v_ax1, v_ax2])
        # x^2
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(3072)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # sum(x^2) over the last axis
        for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(3072)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # 0.00032552083... == 1 / 3072, so the argument is mean(x^2) + 1e-5.
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00032552083333333332) + T.float32(1.0000000000000001e-05))
        # weight (float16) -> float32
        for ax0 in range(T.int64(3072)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(3072), ax0)
                T.reads(transformer_h_0_ln_weight2[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", transformer_h_0_ln_weight2[v_ax0])
        # x * rsqrt(...) * weight
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(3072)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # result -> float16 output
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(3072)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func
    def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, var_top_prob_probs: T.handle, var_top_prob_indices: T.handle):
        # Gather probabilities for two kinds of requests in one flat loop of
        # num_positions + num_samples iterations:
        #   vi <  num_positions: top_prob_offsets[vi] is a flat (row * vocab_size + col)
        #       offset into sorted_indices. Store the token id found there in
        #       top_prob_indices[vi], and that token's probability from
        #       unsorted_probs in top_prob_probs[vi].
        #   vi >= num_positions: store
        #       unsorted_probs[sample_indices[vj], sampling_results[vj]]
        #       in sampled_values[vj], where vj = vi - num_positions.
        # NOTE(review): no thread bindings here despite the cuda target; presumably
        # scheduled by a later pass. Confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size))
        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32")
        num_samples = T.int32(is_size_var=True)
        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32")
        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32")
        num_positions = T.int32(is_size_var=True)
        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32")
        sampled_values = T.match_buffer(var_sampled_values, (num_samples,))
        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,))
        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
        # with T.block("root"):
        for i in range(num_positions + num_samples):
            with T.block("block"):
                vi = T.axis.spatial(num_positions + num_samples, i)
                # The access-region annotations below take the union (min/max bounding
                # box) of both branches. They name e.g. sample_indices[vi - num_positions]
                # even when vi < num_positions; only the guarded body performs loads.
                T.reads(top_prob_offsets[vi], sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], unsorted_probs[T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]):T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + (T.max(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + 1 - T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions])), T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]):T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + (T.max(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]))], sample_indices[vi - num_positions], sampling_results[vi - num_positions])
                T.writes(top_prob_indices[vi], top_prob_probs[vi], sampled_values[vi - num_positions])
                if vi < num_positions:
                    # Decode the flat offset into (row, col) of sorted_indices.
                    row: T.int32 = top_prob_offsets[vi] // vocab_size
                    col: T.int32 = top_prob_offsets[vi] % vocab_size
                    top_prob_indices[vi] = sorted_indices[row, col]
                    top_prob_probs[vi] = unsorted_probs[row, sorted_indices[row, col]]
                else:
                    vj: T.int32 = vi - num_positions
                    sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]

    @T.prim_func
    def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row scatter: copy row b of src (batch_size, n) into row indices[b] of
        # dst (m, n). Rows of dst not named in indices are left untouched.
        # NOTE(review): if indices contains duplicates, which source row ends up in
        # dst is not defined by this function. There are also no thread bindings
        # here despite the cuda target; presumably a later pass schedules it. Confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n))
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        m = T.int32(is_size_var=True)
        dst = T.match_buffer(var_dst, (m, n))
        # with T.block("root"):
        for b, j in T.grid(batch_size, n):
            with T.block("scatter_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                T.reads(src[vb, vj], indices[vb])
                T.writes(dst[indices[vb], vj])
                dst[indices[vb], vj] = src[vb, vj]

    @T.prim_func
    def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle):
        # Final pass of a chunked, temperature-aware softmax over A[batch_size, vocab_size].
        #
        # blockIdx.x enumerates (row, chunk) pairs: v0 = row, v1 = chunk. Every block first
        # combines the per-chunk statistics of its whole row (chunked_max / chunked_sum over all
        # num_chunks chunks; recomputed redundantly by each block of that row), then writes
        # softmax for its own 4096-wide column chunk [v1 * 4096, v1 * 4096 + 4096), clipped to
        # vocab_size.
        #
        # Two branches, selected per row by temperature[v0] > 1e-5:
        #   * sampling: row_sum = sum_c exp(chunked_sum[c] + chunked_max[c] - row_max) and
        #     softmax = exp(A / temperature - (log(row_sum) + row_max)). This presumes
        #     chunked_max holds the chunk max of A / temperature and chunked_sum the chunk's
        #     log-sum-exp relative to that max -- TODO confirm against the producing kernel.
        #   * greedy (temperature <= 1e-5): row_sum adds chunked_sum only for chunks whose max
        #     equals row_max, and softmax = (A == row_max) / row_sum, i.e. non-zero only at
        #     entries equal to the row max. Here A is compared unscaled, so chunked_max is
        #     presumably the raw max and chunked_sum a count of max-valued entries (TODO confirm).
        #
        # NOTE(review): the 4096 chunk width is hard-coded (4 x 32 x 32 in "log_pad") and must
        # match the chunking used to produce chunked_sum / chunked_max.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size))
        # with T.block("root"):
        # Row-level max / sum; each block only reads and writes index v0 of these.
        temp_max_shared = T.alloc_buffer((batch_size,), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared")
        # One block per (row, chunk) pair.
        for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Row max: 32 threads stride over the chunks, max-reducing chunked_max[v0, :].
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("max"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(chunked_max[v0, v1])
                        T.writes(temp_max_shared[v0])
                        with T.init():
                            temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0)
                        temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1])
            # Row sum of the chunk sums, rescaled to the row max (both branches described above).
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("sum_exp"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0])
                        T.writes(temp_sum_shared[v0])
                        with T.init():
                            temp_sum_shared[v0] = T.float32(0.0)
                        temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1])
            # Write this block's chunk: 4 (serial) x 32 (threadIdx.y) x 32 (threadIdx.x) = 4096 columns.
            for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("log_pad"):
                            v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                            v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks)
                            v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0])
                            T.writes(softmax[v0, v1 * T.int64(4096) + v2])
                            # The last chunk may extend past vocab_size; skip those padding columns.
                            if v1 * T.int64(4096) + v2 < vocab_size:
                                softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0])

    @T.prim_func(private=True)
    def take(var_rms_norm194: T.handle, var_logit_positions: T.handle, var_T_take: T.handle):
        # Gather along axis 1:
        #   T_take[0, i, k] = rms_norm194[0, logit_positions[i], k]  for i < batch_size, k < 3072.
        # Picks batch_size rows (float16, width 3072) out of a seq_len-long sequence; presumably
        # the hidden states at the positions whose logits are needed (caller not visible here).
        # NOTE(review): no bounds check -- every logit_positions[i] is assumed to be in
        # [0, seq_len).
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm194 = T.match_buffer(var_rms_norm194, (T.int64(1), seq_len, T.int64(3072)), "float16")
        batch_size = T.int64()
        logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32")
        T_take = T.match_buffer(var_T_take, (T.int64(1), batch_size, T.int64(3072)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), batch_size, T.int64(3072)):
            with T.block("T_take"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rms_norm194[v_ax0, logit_positions[v_ax1], v_ax2], logit_positions[v_ax1])
                T.writes(T_take[v_ax0, v_ax1, v_ax2])
                T_take[v_ax0, v_ax1, v_ax2] = rms_norm194[v_ax0, logit_positions[v_ax1], v_ax2]

    @T.prim_func(private=True)
    def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle):
        # Row-wise gather: take_sorted_probs[i, j] = probs[i, lv1[i, j]].
        # Judging by the name, lv1 presumably holds per-row sort indices (e.g. from an argsort),
        # so the output is probs reordered into sorted order -- TODO confirm at the call site.
        # NOTE(review): the output shape uses its own size vars (batch_size_1, vocab_size_1) and
        # the loop runs over that shape, so it is assumed to match (batch_size, vocab_size);
        # lv1 entries are assumed to lie in [0, vocab_size). Neither is checked here.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size))
        lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32")
        batch_size_1, vocab_size_1 = T.int64(is_size_var=True), T.int64(is_size_var=True)
        take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1))
        # with T.block("root"):
        for i, j in T.grid(batch_size_1, vocab_size_1):
            with T.block("take_sorted_probs"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(probs[v_i, lv1[v_i, v_j]], lv1[v_i, v_j])
                T.writes(take_sorted_probs[v_i, v_j])
                take_sorted_probs[v_i, v_j] = probs[v_i, lv1[v_i, v_j]]

    @T.prim_func
    def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64):
        # Debug helper: copy one layer's K and V out of the paged KV cache into dense tensors.
        #
        # For each sequence position p < seqlen, position_map[p] is a flat slot index into
        # `pages`: page = slot // page_size, row within page = slot % page_size. Index 0 on
        # axis 1 of `pages` is copied to k_data and index 1 to v_data:
        #   k_data[layer_id, p, h, d] = pages[page, 0, h, row, d]
        #   v_data[layer_id, p, h, d] = pages[page, 1, h, row, d]
        # for h < 32, d < 96. Axis 0 of k_data / v_data (extent 32) is indexed by layer_id,
        # presumably the number of layers (TODO confirm).
        # NOTE(review): position_map entries are used as-is; unlike tir_kv_cache_transpose_append
        # there is no skip for -1, so every entry is assumed to be a valid slot.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        num_pages, page_size = T.int64(), T.int64(is_size_var=True)
        pages = T.match_buffer(var_pages, (num_pages, 2, 32, page_size, 96), "float16", offset_factor=1)
        seqlen = T.int64(is_size_var=True)
        position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1)
        k_data = T.match_buffer(var_k_data, (32, seqlen, 32, 96), "float16")
        v_data = T.match_buffer(var_v_data, (32, seqlen, 32, 96), "float16")
        # with T.block("root"):
        for p, h, d in T.grid(seqlen, 32, 96):
            with T.block("copy0"):
                vp, vh, vd = T.axis.remap("SSS", [p, h, d])
                T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd])
                T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                position: T.int32 = position_map[vp]
                # The int32 slot is widened to int64 so // and % are done against the int64 page_size.
                k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd]
                v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd]

    @T.prim_func
    def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle):
        # Scatter new tokens' K and V into the paged KV cache (page size fixed at 16 here).
        #
        # For each token g < ntoken with position_map[g] != -1, with slot = position_map[g]:
        #   pages[slot // 16, 0, h, slot % 16, f] = k_data[g, h, f]
        #   pages[slot // 16, 1, h, slot % 16, f] = v_data[g, h, f]
        # for h < 32, f < 96. Tokens mapped to -1 are skipped (nothing is written for them).
        # NOTE(review): "transpose" presumably refers to moving from the token-major
        # (ntoken, 32, 96) inputs into the (page, 2, 32, 16, 96) page layout. Slots are assumed
        # to be < num_pages * 16; this is not checked here.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        num_pages = T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 32, 16, 96), "float16", offset_factor=1)
        ntoken = T.int64(is_size_var=True)
        k_data = T.match_buffer(var_k_data, (ntoken, 32, 96), "float16")
        v_data = T.match_buffer(var_v_data, (ntoken, 32, 96), "float16")
        position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1)
        # with T.block("root"):
        for global_pos, h, f in T.grid(ntoken, 32, 96):
            # Entries equal to -1 are skipped: neither K nor V is written for that token.
            if position_map[global_pos] != -1:
                with T.block("k_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf]
                with T.block("v_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf]

    @T.prim_func(private=True)
    def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle):
        # Find, per row b of prob[B, N], a probability cutoff ("pivot") for top-p sampling.
        #
        # One CUDA block of 1024 threads per row. The search keeps a bracket [R, L] in shared
        # memory (L starts at 1 - top_p, R at 1e-7) and evaluates three pivots per round. The
        # first round uses init_pivots[b, 0:3]; later rounds use
        # pivot[k] = L - (k + 1) * (L - R) / 4, i.e. descending from L towards R.
        # For each pivot, a round computes over prob[b, :]:
        #   lsum -- sum of entries >= pivot,
        #   lmin -- smallest entry >= pivot,
        #   cmin -- number of entries equal to lmin.
        # Thread 0 then accepts pivot k if lsum >= top_p > lsum - cmin * lmin (dropping the
        # smallest kept value would fall below top_p). Otherwise it moves R up to the pivot when
        # lsum - cmin * lmin >= top_p, or moves L down to the pivot when lsum < top_p.
        #
        # Outputs (written by thread 0 only):
        #   final_pivot[b] -- the accepted pivot; R if the bracket collapses (L - R <= 1e-7)
        #                     without a hit; 0.0 when the bracket is degenerate from the start.
        #   final_lsum[b]  -- lsum of the accepted pivot, or of the last pivot that moved R;
        #                     lsum[2] of the last round if R still equals its initial 1e-7;
        #                     1.0 in the degenerate case.
        # NOTE(review): init_pivots[b, :] is presumably expected in descending order like the
        # later-round pivots; the early-exit test below relies on pivot[2] being the smallest.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        top_p_arr = T.match_buffer(var_top_p_arr, (B,))
        init_pivots = T.match_buffer(var_init_pivots, (B, 3))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        # with T.block("root"):
        # "local" buffers hold per-thread state. The "shared" ones (L, R_1, lmin_broadcast, es,
        # find_pivot) carry values written by thread 0 to the rest of the block after a sync.
        pivot = T.alloc_buffer((3,), scope="local")
        top_p = T.alloc_buffer((1,), scope="local")
        L = T.alloc_buffer((1,), scope="shared")
        R_1 = T.alloc_buffer((1,), scope="shared")
        L_local = T.alloc_buffer((1,), scope="local")
        R_local = T.alloc_buffer((1,), scope="local")
        q = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((3,), scope="local")
        lmin_broadcast = T.alloc_buffer((1,), scope="shared")
        lmin_broadcast_local = T.alloc_buffer((1,), scope="local")
        lmin = T.alloc_buffer((3,), scope="local")
        cmin = T.alloc_buffer((3,), "int32", scope="local")
        total_sum = T.alloc_buffer((1,), scope="local")
        it = T.alloc_buffer((1,), "int32", scope="local")
        es_local = T.alloc_buffer((1,), "bool", scope="local")
        es = T.alloc_buffer((1,), "bool", scope="shared")
        find_pivot_local = T.alloc_buffer((1,), "bool", scope="local")
        find_pivot = T.alloc_buffer((1,), "bool", scope="shared")
        total_sum_reduce = T.alloc_buffer((1,), scope="local")
        lsum_reduce = T.alloc_buffer((1,), scope="local")
        lmin_reduce = T.alloc_buffer((1,), scope="local")
        cmin_reduce = T.alloc_buffer((1,), "int32", scope="local")
        for _bx in T.thread_binding(B, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0])
                    T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0])
                    top_p[0] = top_p_arr[b]
                    # Thread 0 initialises the shared bracket [R, L] = [1e-7, 1 - top_p].
                    if tx == 0:
                        L[0] = T.float32(1.0) - top_p[0]
                        R_1[0] = T.float32(9.9999999999999995e-08)
                        find_pivot[0] = T.bool(False)
                    T.tvm_storage_sync("shared")
                    L_local[0] = L[0]
                    R_local[0] = R_1[0]
                    # The first round uses the caller-provided initial pivots.
                    for i in T.unroll(3):
                        pivot[i] = init_pivots[b, i]
                    find_pivot_local[0] = T.bool(False)
                    # Degenerate bracket (top_p >= 1 - 2e-7): pivot 0.0 with lsum 1.0, no search.
                    if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08):
                        if tx == 0:
                            final_lsum[b] = T.float32(1.0)
                            final_pivot[b] = T.float32(0.0)
                        find_pivot_local[0] = T.bool(True)
                    # One search round per iteration. The condition only uses values broadcast
                    # through shared memory, so it is uniform across the block.
                    while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]):
                        T.tvm_storage_sync("shared")
                        for pidx in T.unroll(3):
                            lsum[pidx] = T.float32(0.0)
                            lmin[pidx] = T.float32(340282346638528859811704183484516925440.0)
                            cmin[pidx] = 0
                        total_sum[0] = T.float32(0.0)
                        it[0] = 0
                        es_local[0] = T.bool(False)
                        # Thread tx scans prob[b, tx], prob[b, tx + 1024], ... and accumulates,
                        # for each pivot, the sum / min / count-of-min of entries >= pivot.
                        while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]:
                            q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0.0))
                            total_sum[0] = total_sum[0] + q[0]
                            for pidx in T.unroll(3):
                                if q[0] >= pivot[pidx]:
                                    lsum[pidx] = lsum[pidx] + q[0]
                                    if lmin[pidx] > q[0]:
                                        lmin[pidx] = q[0]
                                        cmin[pidx] = 1
                                    else:
                                        if lmin[pidx] == q[0]:
                                            cmin[pidx] = cmin[pidx] + 1
                            it[0] = it[0] + 1
                            # Every 32 steps, stop early once the unscanned mass (1 - scanned
                            # total, assuming the row sums to 1) is below pivot[2]: no remaining
                            # entry can then reach any pivot.
                            if it[0] % 32 == 0:
                                with T.block("block_cross_thread"):
                                    T.reads(total_sum[0])
                                    T.writes(total_sum_reduce[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx)
                                if tx == 0:
                                    es[0] = T.float32(1.0) - total_sum_reduce[0] < pivot[2]
                                T.tvm_storage_sync("shared")
                                es_local[0] = es[0]
                        T.tvm_storage_sync("shared")
                        # Block-wide reduction per pivot. Afterwards thread 0 holds the row-wide
                        # lsum / lmin / cmin in its own local buffers.
                        for pidx in range(3):
                            with T.block("block_cross_thread"):
                                T.reads(lsum[pidx])
                                T.writes(lsum_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                # The reducer's identity must be neutral for min. 0.0 is not
                                # (min(0.0, x) is 0.0 for any x >= 0, not x), so use float32 max,
                                # the same sentinel lmin is initialised with above.
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(340282346638528859811704183484516925440.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            # Only threads whose local min equals the row-wide min contribute
                            # to the count of min-valued entries.
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        # Thread 0 checks the pivots in order: accept one, raise R, or lower L.
                        if tx == 0:
                            it[0] = 0
                            while it[0] < 3 and not find_pivot_local[0]:
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        # Broadcast the updated bracket and found-flag, then place the next three
                        # pivots 1/4, 2/4 and 3/4 of the way from L towards R.
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4.0)
                    # Bracket collapsed without an accepted pivot: fall back to R.
                    if tx == 0:
                        if not find_pivot_local[0]:
                            final_pivot[b] = R_local[0]
                            # R never moved, so final_lsum was never set by the search:
                            # use lsum of the smallest pivot from the last round.
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    @T.prim_func(private=True)
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        # Apply a per-row top-p cutoff and renormalise:
        #   renorm_prob[b, i] = prob[b, i] / final_lsum[b]  if prob[b, i] >= final_pivot[b]
        #                       0.0                          otherwise
        # final_pivot / final_lsum share names and shapes with the outputs of
        # top_p_pivot_cutoff and are presumably produced by it (TODO confirm at the call site).
        #
        # Launch shape: blockIdx.y = row b; blockIdx.x ranges over K = 511 // B + 1
        # (== (512 + B - 1) // B == ceil(512 / B) for B >= 1) blocks per row, so B * K (about
        # 512) blocks in total, each with 1024 threads. Thread (bx, tx) handles the columns
        # i * K * 1024 + bx * 1024 + tx for i in [0, ceil(N / (K * 1024))), skipping those >= N.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding(511 // B + 1, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        T.reads(final_pivot[by], final_lsum[by], prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)])
                        # Load this row's cutoff and normaliser into thread-local storage once.
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        # Trip count == ceil(N / (K * 1024)); the stride (512 + B - 1) // B * 1024 == K * 1024.
                        for i in range((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024)):
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0.0))

    @T.prim_func
    def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 32, 96), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 32, 16, 96), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 32, 96), "float16")
        lse = T.match_buffer(var_lse, (total_len, 32))
        tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1)
        total_tree_order_len = T.int32(is_size_var=True)
        tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1)
        # with T.block("root"):
        assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention."
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(32, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 96), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 96), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 6):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(6):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 96)
                                                        j = T.axis.spatial(96, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 96)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.float32(1.1902380714238083) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.float32(1.1902380714238083) * T.Cast("float32", T.if_then_else(j < 48, q[cur_L, cur_H_qo, j + 48] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 48]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 96) / T.float32(96.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_ly_fused_0 in range(6):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 96)
                                                            j = T.axis.spatial(96, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 96)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_ly_fused_0 in range(6):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 96)
                                                            j = T.axis.spatial(96, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 96)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:96], K_smem[0:32, 0:96])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(12, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k = T.axis.reduce(96, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.14724444602590309)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:96])
                                            T.writes(O_local[0:32, 0:96])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 6):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 6 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 6):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1)
                                                            k = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 6):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(96, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 6 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @R.function
    def alloc_embedding_tensor() -> R.Tensor((2048, 3072), dtype="float16"):
        # Allocate and return a single (2048, 3072) float16 tensor; takes no inputs.
        # This function writes nothing into the tensor -- presumably the caller
        # fills it before use.
        # NOTE(review): 2048 equals the "seq_len" upper bound and 3072 the last
        # dimension of `input_embeds` in `batch_decode`; presumably this is a
        # scratch buffer for up to 2048 token embeddings -- confirm with the
        # model's compile configuration.
        # Relax memory-planning hint for this function's output -- TODO confirm
        # exact semantics in TVM's memory-planning pass documentation.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True})
        # alloc_tensor arguments: shape, dtype, device index 0 (presumably the
        # runtime device index), storage scope "global".
        gv: R.Tensor((2048, 3072), dtype="float16") = R.builtin.alloc_tensor(R.shape([2048, 3072]), R.dtype("float16"), R.prim_value(0), R.str("global"))
        return gv

    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        # Sort each row of `probs` and return (sorted values, sort indices).
        #
        # probs: (batch_size, vocab_size) float32.
        # Returns a 2-tuple:
        #   [0] (batch_size, vocab_size) float32 produced by `take_sorted_probs`
        #       (presumably `probs` gathered in the order given by [1]);
        #   [1] (batch_size, vocab_size) int32 produced by `argsort_thrust`
        #       (presumably per-row argsort indices; the sort direction is
        #       decided inside `argsort_thrust`, which is not visible here).
        # Symbolic size variables for the two dimensions of `probs`.
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # Attributes: vocab_size is declared non-negative; batch_size is bounded
        # by 80, presumably so memory planning can size buffers for the worst
        # case -- confirm against TVM's memory-planning pass docs.
        # NOTE(review): the "num_positions" and "num_samples" bounds are not
        # referenced by this function's signature or body; presumably a shared
        # attribute dict emitted for all sampling functions.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Raw uint8 workspace passed to `argsort_thrust`. Its byte size is
            # 8*(B*V*4) + B*V*12 + 8388608 = 44*B*V + 8 MiB (B = batch_size,
            # V = vocab_size). NOTE(review): how `argsort_thrust` partitions this
            # scratch space is not visible here -- confirm before changing it.
            lv: R.Tensor((8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12,), dtype="uint8") = R.builtin.alloc_tensor(R.shape([8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12]), R.dtype("uint8"), R.prim_value(0), R.str("global"))
            # (batch_size, vocab_size) int32 result of the thrust-based argsort kernel.
            lv1 = R.call_tir(cls.argsort_thrust, (probs, lv), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="int32"))
            # (batch_size, vocab_size) float32 result computed from `probs` and the indices.
            lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            # Tuple order is (values, indices).
            gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1
            R.output(gv1)
        return gv1

    @R.function
    def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 3072), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32064, 384), dtype="uint32"), R.Tensor((32064, 96), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), 
R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 
1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), 
dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), 
R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), 
dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), 
R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor(("vocab_size", 384), dtype="uint32"), R.Tensor(("vocab_size", 96), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), R.Object):
        batch_size = T.int64()
        vocab_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        with R.dataflow():
            transformer_h_0_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[2]
            transformer_h_0_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[3]
            transformer_h_0_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[4]
            transformer_h_0_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[5]
            transformer_h_0_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[6]
            transformer_h_0_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[7]
            transformer_h_0_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[8]
            transformer_h_0_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[9]
            transformer_h_0_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[10]
            transformer_h_0_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[11]
            transformer_h_1_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[12]
            transformer_h_1_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[13]
            transformer_h_1_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[14]
            transformer_h_1_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[15]
            transformer_h_1_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[16]
            transformer_h_1_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[17]
            transformer_h_1_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[18]
            transformer_h_1_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[19]
            transformer_h_1_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[20]
            transformer_h_1_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[21]
            transformer_h_2_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[22]
            transformer_h_2_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[23]
            transformer_h_2_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[24]
            transformer_h_2_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[25]
            transformer_h_2_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[26]
            transformer_h_2_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[27]
            transformer_h_2_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[28]
            transformer_h_2_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[29]
            transformer_h_2_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[30]
            transformer_h_2_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[31]
            transformer_h_3_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[32]
            transformer_h_3_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[33]
            transformer_h_3_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[34]
            transformer_h_3_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[35]
            transformer_h_3_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[36]
            transformer_h_3_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[37]
            transformer_h_3_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[38]
            transformer_h_3_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[39]
            transformer_h_3_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[40]
            transformer_h_3_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[41]
            transformer_h_4_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[42]
            transformer_h_4_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[43]
            transformer_h_4_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[44]
            transformer_h_4_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[45]
            transformer_h_4_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[46]
            transformer_h_4_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[47]
            transformer_h_4_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[48]
            transformer_h_4_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[49]
            transformer_h_4_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[50]
            transformer_h_4_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[51]
            transformer_h_5_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[52]
            transformer_h_5_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[53]
            transformer_h_5_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[54]
            transformer_h_5_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[55]
            transformer_h_5_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[56]
            transformer_h_5_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[57]
            transformer_h_5_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[58]
            transformer_h_5_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[59]
            transformer_h_5_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[60]
            transformer_h_5_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[61]
            transformer_h_6_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[62]
            transformer_h_6_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[63]
            transformer_h_6_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[64]
            transformer_h_6_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[65]
            transformer_h_6_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[66]
            transformer_h_6_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[67]
            transformer_h_6_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[68]
            transformer_h_6_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[69]
            transformer_h_6_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[70]
            transformer_h_6_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[71]
            transformer_h_7_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[72]
            transformer_h_7_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[73]
            transformer_h_7_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[74]
            transformer_h_7_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[75]
            transformer_h_7_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[76]
            transformer_h_7_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[77]
            transformer_h_7_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[78]
            transformer_h_7_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[79]
            transformer_h_7_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[80]
            transformer_h_7_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[81]
            transformer_h_8_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[82]
            transformer_h_8_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[83]
            transformer_h_8_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[84]
            transformer_h_8_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[85]
            transformer_h_8_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[86]
            transformer_h_8_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[87]
            transformer_h_8_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[88]
            transformer_h_8_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[89]
            transformer_h_8_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[90]
            transformer_h_8_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[91]
            transformer_h_9_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[92]
            transformer_h_9_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[93]
            transformer_h_9_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[94]
            transformer_h_9_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[95]
            transformer_h_9_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[96]
            transformer_h_9_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[97]
            transformer_h_9_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[98]
            transformer_h_9_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[99]
            transformer_h_9_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[100]
            transformer_h_9_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[101]
            transformer_h_10_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[102]
            transformer_h_10_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[103]
            transformer_h_10_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[104]
            transformer_h_10_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[105]
            transformer_h_10_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[106]
            transformer_h_10_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[107]
            transformer_h_10_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[108]
            transformer_h_10_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[109]
            transformer_h_10_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[110]
            transformer_h_10_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[111]
            transformer_h_11_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[112]
            transformer_h_11_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[113]
            transformer_h_11_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[114]
            transformer_h_11_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[115]
            transformer_h_11_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[116]
            transformer_h_11_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[117]
            transformer_h_11_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[118]
            transformer_h_11_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[119]
            transformer_h_11_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[120]
            transformer_h_11_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[121]
            transformer_h_12_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[122]
            transformer_h_12_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[123]
            transformer_h_12_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[124]
            transformer_h_12_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[125]
            transformer_h_12_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[126]
            transformer_h_12_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[127]
            transformer_h_12_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[128]
            transformer_h_12_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[129]
            transformer_h_12_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[130]
            transformer_h_12_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[131]
            transformer_h_13_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[132]
            transformer_h_13_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[133]
            transformer_h_13_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[134]
            transformer_h_13_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[135]
            transformer_h_13_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[136]
            transformer_h_13_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[137]
            transformer_h_13_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[138]
            transformer_h_13_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[139]
            transformer_h_13_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[140]
            transformer_h_13_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[141]
            transformer_h_14_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[142]
            transformer_h_14_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[143]
            transformer_h_14_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[144]
            transformer_h_14_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[145]
            transformer_h_14_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[146]
            transformer_h_14_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[147]
            transformer_h_14_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[148]
            transformer_h_14_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[149]
            transformer_h_14_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[150]
            transformer_h_14_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[151]
            transformer_h_15_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[152]
            transformer_h_15_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[153]
            transformer_h_15_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[154]
            transformer_h_15_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[155]
            transformer_h_15_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[156]
            transformer_h_15_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[157]
            transformer_h_15_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[158]
            transformer_h_15_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[159]
            transformer_h_15_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[160]
            transformer_h_15_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[161]
            transformer_h_16_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[162]
            transformer_h_16_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[163]
            transformer_h_16_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[164]
            transformer_h_16_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[165]
            transformer_h_16_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[166]
            transformer_h_16_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[167]
            transformer_h_16_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[168]
            transformer_h_16_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[169]
            transformer_h_16_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[170]
            transformer_h_16_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[171]
            transformer_h_17_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[172]
            transformer_h_17_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[173]
            transformer_h_17_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[174]
            transformer_h_17_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[175]
            transformer_h_17_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[176]
            transformer_h_17_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[177]
            transformer_h_17_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[178]
            transformer_h_17_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[179]
            transformer_h_17_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[180]
            transformer_h_17_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[181]
            transformer_h_18_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[182]
            transformer_h_18_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[183]
            transformer_h_18_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[184]
            transformer_h_18_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[185]
            transformer_h_18_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[186]
            transformer_h_18_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[187]
            transformer_h_18_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[188]
            transformer_h_18_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[189]
            transformer_h_18_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[190]
            transformer_h_18_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[191]
            transformer_h_19_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[192]
            transformer_h_19_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[193]
            transformer_h_19_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[194]
            transformer_h_19_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[195]
            transformer_h_19_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[196]
            transformer_h_19_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[197]
            transformer_h_19_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[198]
            transformer_h_19_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[199]
            transformer_h_19_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[200]
            transformer_h_19_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[201]
            transformer_h_20_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[202]
            transformer_h_20_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[203]
            transformer_h_20_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[204]
            transformer_h_20_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[205]
            transformer_h_20_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[206]
            transformer_h_20_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[207]
            transformer_h_20_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[208]
            transformer_h_20_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[209]
            transformer_h_20_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[210]
            transformer_h_20_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[211]
            transformer_h_21_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[212]
            transformer_h_21_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[213]
            transformer_h_21_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[214]
            transformer_h_21_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[215]
            transformer_h_21_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[216]
            transformer_h_21_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[217]
            transformer_h_21_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[218]
            transformer_h_21_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[219]
            transformer_h_21_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[220]
            transformer_h_21_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[221]
            transformer_h_22_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[222]
            transformer_h_22_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[223]
            transformer_h_22_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[224]
            transformer_h_22_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[225]
            transformer_h_22_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[226]
            transformer_h_22_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[227]
            transformer_h_22_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[228]
            transformer_h_22_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[229]
            transformer_h_22_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[230]
            transformer_h_22_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[231]
            transformer_h_23_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[232]
            transformer_h_23_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[233]
            transformer_h_23_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[234]
            transformer_h_23_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[235]
            transformer_h_23_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[236]
            transformer_h_23_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[237]
            transformer_h_23_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[238]
            transformer_h_23_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[239]
            transformer_h_23_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[240]
            transformer_h_23_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[241]
            transformer_h_24_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[242]
            transformer_h_24_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[243]
            transformer_h_24_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[244]
            transformer_h_24_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[245]
            transformer_h_24_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[246]
            transformer_h_24_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[247]
            transformer_h_24_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[248]
            transformer_h_24_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[249]
            transformer_h_24_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[250]
            transformer_h_24_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[251]
            transformer_h_25_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[252]
            transformer_h_25_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[253]
            transformer_h_25_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[254]
            transformer_h_25_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[255]
            transformer_h_25_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[256]
            transformer_h_25_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[257]
            transformer_h_25_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[258]
            transformer_h_25_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[259]
            transformer_h_25_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[260]
            transformer_h_25_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[261]
            transformer_h_26_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[262]
            transformer_h_26_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[263]
            transformer_h_26_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[264]
            transformer_h_26_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[265]
            transformer_h_26_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[266]
            transformer_h_26_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[267]
            transformer_h_26_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[268]
            transformer_h_26_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[269]
            transformer_h_26_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[270]
            transformer_h_26_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[271]
            transformer_h_27_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[272]
            transformer_h_27_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[273]
            transformer_h_27_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[274]
            transformer_h_27_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[275]
            transformer_h_27_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[276]
            transformer_h_27_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[277]
            transformer_h_27_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[278]
            transformer_h_27_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[279]
            transformer_h_27_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[280]
            transformer_h_27_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[281]
            transformer_h_28_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[282]
            transformer_h_28_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[283]
            transformer_h_28_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[284]
            transformer_h_28_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[285]
            transformer_h_28_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[286]
            transformer_h_28_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[287]
            transformer_h_28_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[288]
            transformer_h_28_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[289]
            transformer_h_28_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[290]
            transformer_h_28_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[291]
            transformer_h_29_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[292]
            transformer_h_29_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[293]
            transformer_h_29_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[294]
            transformer_h_29_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[295]
            transformer_h_29_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[296]
            transformer_h_29_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[297]
            transformer_h_29_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[298]
            transformer_h_29_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[299]
            transformer_h_29_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[300]
            transformer_h_29_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[301]
            transformer_h_30_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[302]
            transformer_h_30_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[303]
            transformer_h_30_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[304]
            transformer_h_30_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[305]
            transformer_h_30_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[306]
            transformer_h_30_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[307]
            transformer_h_30_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[308]
            transformer_h_30_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[309]
            transformer_h_30_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[310]
            transformer_h_30_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[311]
            transformer_h_31_ln_weight4: R.Tensor((3072,), dtype="float16") = packed_params[312]
            transformer_h_31_mixer_qkv_proj_q_weight4: R.Tensor((9216, 384), dtype="uint32") = packed_params[313]
            transformer_h_31_mixer_qkv_proj_q_scale4: R.Tensor((9216, 96), dtype="float16") = packed_params[314]
            transformer_h_31_mixer_out_proj_q_weight4: R.Tensor((3072, 384), dtype="uint32") = packed_params[315]
            transformer_h_31_mixer_out_proj_q_scale4: R.Tensor((3072, 96), dtype="float16") = packed_params[316]
            transformer_h_31_mlp_gate_up_proj_q_weight4: R.Tensor((16384, 384), dtype="uint32") = packed_params[317]
            transformer_h_31_mlp_gate_up_proj_q_scale4: R.Tensor((16384, 96), dtype="float16") = packed_params[318]
            transformer_h_31_mlp_down_proj_q_weight4: R.Tensor((3072, 1024), dtype="uint32") = packed_params[319]
            transformer_h_31_mlp_down_proj_q_scale4: R.Tensor((3072, 256), dtype="float16") = packed_params[320]
            transformer_h_31_post_attention_layernorm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[321]
            transformer_norm_weight4: R.Tensor((3072,), dtype="float16") = packed_params[322]
            lm_head_q_weight4: R.Tensor((vocab_size, 384), dtype="uint32") = packed_params[323]
            lm_head_q_scale4: R.Tensor((vocab_size, 96), dtype="float16") = packed_params[324]
            rms_norm195 = R.call_tir(cls.rms_norm, (input_embeds, transformer_h_0_ln_weight4), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_0_mixer_qkv_proj_q_weight4, transformer_h_0_mixer_qkv_proj_q_scale4, rms_norm195), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape384 = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape385 = R.call_tir(cls.reshape1, (reshape384,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv486 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape385), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape386 = R.call_tir(cls.reshape2, (lv486,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape387 = R.call_tir(cls.reshape3, (reshape386,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_0_mixer_out_proj_q_weight4, transformer_h_0_mixer_out_proj_q_scale4, reshape387), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv_1 = R.call_tir(cls.fuse_add_norm_decode, (lv1, input_embeds, transformer_h_0_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv1_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv_1[1]
            rms_norm196: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv_1[0]
            lv2 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_0_mlp_gate_up_proj_q_weight4, transformer_h_0_mlp_gate_up_proj_q_scale4, rms_norm196), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv_2 = R.call_tir(cls.fused_split_silu_multiply, (lv2,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv3 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_0_mlp_down_proj_q_weight4, transformer_h_0_mlp_down_proj_q_scale4, lv_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv2_1 = R.call_tir(cls.fuse_add_norm_decode, (lv3, lv1_1, transformer_h_1_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv3_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv2_1[1]
            rms_norm197: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv2_1[0]
            lv4 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_1_mixer_qkv_proj_q_weight4, transformer_h_1_mixer_qkv_proj_q_scale4, rms_norm197), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape388 = R.call_tir(cls.reshape, (lv4,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape389 = R.call_tir(cls.reshape1, (reshape388,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv491 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape389), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape390 = R.call_tir(cls.reshape2, (lv491,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape391 = R.call_tir(cls.reshape3, (reshape390,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv5 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_1_mixer_out_proj_q_weight4, transformer_h_1_mixer_out_proj_q_scale4, reshape391), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv4_1 = R.call_tir(cls.fuse_add_norm_decode, (lv5, lv3_1, transformer_h_1_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv5_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv4_1[1]
            rms_norm198: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv4_1[0]
            lv6 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_1_mlp_gate_up_proj_q_weight4, transformer_h_1_mlp_gate_up_proj_q_scale4, rms_norm198), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv1_2 = R.call_tir(cls.fused_split_silu_multiply, (lv6,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv7 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_1_mlp_down_proj_q_weight4, transformer_h_1_mlp_down_proj_q_scale4, lv1_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv6_1 = R.call_tir(cls.fuse_add_norm_decode, (lv7, lv5_1, transformer_h_2_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv7_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv6_1[1]
            rms_norm199: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv6_1[0]
            lv8 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_2_mixer_qkv_proj_q_weight4, transformer_h_2_mixer_qkv_proj_q_scale4, rms_norm199), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape392 = R.call_tir(cls.reshape, (lv8,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape393 = R.call_tir(cls.reshape1, (reshape392,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv496 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape393), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape394 = R.call_tir(cls.reshape2, (lv496,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape395 = R.call_tir(cls.reshape3, (reshape394,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv9 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_2_mixer_out_proj_q_weight4, transformer_h_2_mixer_out_proj_q_scale4, reshape395), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv8_1 = R.call_tir(cls.fuse_add_norm_decode, (lv9, lv7_1, transformer_h_2_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv9_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv8_1[1]
            rms_norm200: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv8_1[0]
            lv10 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_2_mlp_gate_up_proj_q_weight4, transformer_h_2_mlp_gate_up_proj_q_scale4, rms_norm200), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv2_2 = R.call_tir(cls.fused_split_silu_multiply, (lv10,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv11 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_2_mlp_down_proj_q_weight4, transformer_h_2_mlp_down_proj_q_scale4, lv2_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv10_1 = R.call_tir(cls.fuse_add_norm_decode, (lv11, lv9_1, transformer_h_3_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv11_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv10_1[1]
            rms_norm201: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv10_1[0]
            lv12 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_3_mixer_qkv_proj_q_weight4, transformer_h_3_mixer_qkv_proj_q_scale4, rms_norm201), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape396 = R.call_tir(cls.reshape, (lv12,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape397 = R.call_tir(cls.reshape1, (reshape396,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv501 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape397), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape398 = R.call_tir(cls.reshape2, (lv501,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape399 = R.call_tir(cls.reshape3, (reshape398,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv13 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_3_mixer_out_proj_q_weight4, transformer_h_3_mixer_out_proj_q_scale4, reshape399), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv12_1 = R.call_tir(cls.fuse_add_norm_decode, (lv13, lv11_1, transformer_h_3_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv13_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv12_1[1]
            rms_norm202: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv12_1[0]
            lv14 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_3_mlp_gate_up_proj_q_weight4, transformer_h_3_mlp_gate_up_proj_q_scale4, rms_norm202), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv3_2 = R.call_tir(cls.fused_split_silu_multiply, (lv14,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv15 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_3_mlp_down_proj_q_weight4, transformer_h_3_mlp_down_proj_q_scale4, lv3_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv14_1 = R.call_tir(cls.fuse_add_norm_decode, (lv15, lv13_1, transformer_h_4_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv15_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv14_1[1]
            rms_norm203: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv14_1[0]
            lv16 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_4_mixer_qkv_proj_q_weight4, transformer_h_4_mixer_qkv_proj_q_scale4, rms_norm203), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape400 = R.call_tir(cls.reshape, (lv16,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape401 = R.call_tir(cls.reshape1, (reshape400,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv506 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape401), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape402 = R.call_tir(cls.reshape2, (lv506,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape403 = R.call_tir(cls.reshape3, (reshape402,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv17 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_4_mixer_out_proj_q_weight4, transformer_h_4_mixer_out_proj_q_scale4, reshape403), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv16_1 = R.call_tir(cls.fuse_add_norm_decode, (lv17, lv15_1, transformer_h_4_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv17_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv16_1[1]
            rms_norm204: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv16_1[0]
            lv18 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_4_mlp_gate_up_proj_q_weight4, transformer_h_4_mlp_gate_up_proj_q_scale4, rms_norm204), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv4_2 = R.call_tir(cls.fused_split_silu_multiply, (lv18,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv19 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_4_mlp_down_proj_q_weight4, transformer_h_4_mlp_down_proj_q_scale4, lv4_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv18_1 = R.call_tir(cls.fuse_add_norm_decode, (lv19, lv17_1, transformer_h_5_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv19_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv18_1[1]
            rms_norm205: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv18_1[0]
            lv20 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_5_mixer_qkv_proj_q_weight4, transformer_h_5_mixer_qkv_proj_q_scale4, rms_norm205), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape404 = R.call_tir(cls.reshape, (lv20,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape405 = R.call_tir(cls.reshape1, (reshape404,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv511 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape405), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape406 = R.call_tir(cls.reshape2, (lv511,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape407 = R.call_tir(cls.reshape3, (reshape406,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv21 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_5_mixer_out_proj_q_weight4, transformer_h_5_mixer_out_proj_q_scale4, reshape407), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv20_1 = R.call_tir(cls.fuse_add_norm_decode, (lv21, lv19_1, transformer_h_5_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv21_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv20_1[1]
            rms_norm206: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv20_1[0]
            lv22 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_5_mlp_gate_up_proj_q_weight4, transformer_h_5_mlp_gate_up_proj_q_scale4, rms_norm206), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv5_2 = R.call_tir(cls.fused_split_silu_multiply, (lv22,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv23 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_5_mlp_down_proj_q_weight4, transformer_h_5_mlp_down_proj_q_scale4, lv5_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv22_1 = R.call_tir(cls.fuse_add_norm_decode, (lv23, lv21_1, transformer_h_6_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv23_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv22_1[1]
            rms_norm207: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv22_1[0]
            lv24 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_6_mixer_qkv_proj_q_weight4, transformer_h_6_mixer_qkv_proj_q_scale4, rms_norm207), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape408 = R.call_tir(cls.reshape, (lv24,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape409 = R.call_tir(cls.reshape1, (reshape408,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv516 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape409), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape410 = R.call_tir(cls.reshape2, (lv516,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape411 = R.call_tir(cls.reshape3, (reshape410,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv25 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_6_mixer_out_proj_q_weight4, transformer_h_6_mixer_out_proj_q_scale4, reshape411), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv24_1 = R.call_tir(cls.fuse_add_norm_decode, (lv25, lv23_1, transformer_h_6_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv25_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv24_1[1]
            rms_norm208: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv24_1[0]
            lv26 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_6_mlp_gate_up_proj_q_weight4, transformer_h_6_mlp_gate_up_proj_q_scale4, rms_norm208), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv6_2 = R.call_tir(cls.fused_split_silu_multiply, (lv26,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv27 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_6_mlp_down_proj_q_weight4, transformer_h_6_mlp_down_proj_q_scale4, lv6_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv26_1 = R.call_tir(cls.fuse_add_norm_decode, (lv27, lv25_1, transformer_h_7_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv27_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv26_1[1]
            rms_norm209: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv26_1[0]
            lv28 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_7_mixer_qkv_proj_q_weight4, transformer_h_7_mixer_qkv_proj_q_scale4, rms_norm209), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape412 = R.call_tir(cls.reshape, (lv28,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape413 = R.call_tir(cls.reshape1, (reshape412,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv521 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape413), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape414 = R.call_tir(cls.reshape2, (lv521,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape415 = R.call_tir(cls.reshape3, (reshape414,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv29 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_7_mixer_out_proj_q_weight4, transformer_h_7_mixer_out_proj_q_scale4, reshape415), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv28_1 = R.call_tir(cls.fuse_add_norm_decode, (lv29, lv27_1, transformer_h_7_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv29_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv28_1[1]
            rms_norm210: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv28_1[0]
            lv30 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_7_mlp_gate_up_proj_q_weight4, transformer_h_7_mlp_gate_up_proj_q_scale4, rms_norm210), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv7_2 = R.call_tir(cls.fused_split_silu_multiply, (lv30,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv31 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_7_mlp_down_proj_q_weight4, transformer_h_7_mlp_down_proj_q_scale4, lv7_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv30_1 = R.call_tir(cls.fuse_add_norm_decode, (lv31, lv29_1, transformer_h_8_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv31_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv30_1[1]
            rms_norm211: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv30_1[0]
            lv32 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_8_mixer_qkv_proj_q_weight4, transformer_h_8_mixer_qkv_proj_q_scale4, rms_norm211), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape416 = R.call_tir(cls.reshape, (lv32,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape417 = R.call_tir(cls.reshape1, (reshape416,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv526 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape417), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape418 = R.call_tir(cls.reshape2, (lv526,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape419 = R.call_tir(cls.reshape3, (reshape418,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv33 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_8_mixer_out_proj_q_weight4, transformer_h_8_mixer_out_proj_q_scale4, reshape419), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv32_1 = R.call_tir(cls.fuse_add_norm_decode, (lv33, lv31_1, transformer_h_8_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv33_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv32_1[1]
            rms_norm212: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv32_1[0]
            lv34 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_8_mlp_gate_up_proj_q_weight4, transformer_h_8_mlp_gate_up_proj_q_scale4, rms_norm212), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv8_2 = R.call_tir(cls.fused_split_silu_multiply, (lv34,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv35 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_8_mlp_down_proj_q_weight4, transformer_h_8_mlp_down_proj_q_scale4, lv8_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv34_1 = R.call_tir(cls.fuse_add_norm_decode, (lv35, lv33_1, transformer_h_9_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv35_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv34_1[1]
            rms_norm213: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv34_1[0]
            lv36 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_9_mixer_qkv_proj_q_weight4, transformer_h_9_mixer_qkv_proj_q_scale4, rms_norm213), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape420 = R.call_tir(cls.reshape, (lv36,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape421 = R.call_tir(cls.reshape1, (reshape420,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv531 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape421), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape422 = R.call_tir(cls.reshape2, (lv531,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape423 = R.call_tir(cls.reshape3, (reshape422,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv37 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_9_mixer_out_proj_q_weight4, transformer_h_9_mixer_out_proj_q_scale4, reshape423), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv36_1 = R.call_tir(cls.fuse_add_norm_decode, (lv37, lv35_1, transformer_h_9_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv37_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv36_1[1]
            rms_norm214: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv36_1[0]
            lv38 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_9_mlp_gate_up_proj_q_weight4, transformer_h_9_mlp_gate_up_proj_q_scale4, rms_norm214), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv9_2 = R.call_tir(cls.fused_split_silu_multiply, (lv38,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv39 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_9_mlp_down_proj_q_weight4, transformer_h_9_mlp_down_proj_q_scale4, lv9_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv38_1 = R.call_tir(cls.fuse_add_norm_decode, (lv39, lv37_1, transformer_h_10_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv39_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv38_1[1]
            rms_norm215: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv38_1[0]
            lv40 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_10_mixer_qkv_proj_q_weight4, transformer_h_10_mixer_qkv_proj_q_scale4, rms_norm215), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape424 = R.call_tir(cls.reshape, (lv40,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape425 = R.call_tir(cls.reshape1, (reshape424,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv536 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape425), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape426 = R.call_tir(cls.reshape2, (lv536,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape427 = R.call_tir(cls.reshape3, (reshape426,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv41 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_10_mixer_out_proj_q_weight4, transformer_h_10_mixer_out_proj_q_scale4, reshape427), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv40_1 = R.call_tir(cls.fuse_add_norm_decode, (lv41, lv39_1, transformer_h_10_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv41_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv40_1[1]
            rms_norm216: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv40_1[0]
            lv42 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_10_mlp_gate_up_proj_q_weight4, transformer_h_10_mlp_gate_up_proj_q_scale4, rms_norm216), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv10_2 = R.call_tir(cls.fused_split_silu_multiply, (lv42,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv43 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_10_mlp_down_proj_q_weight4, transformer_h_10_mlp_down_proj_q_scale4, lv10_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv42_1 = R.call_tir(cls.fuse_add_norm_decode, (lv43, lv41_1, transformer_h_11_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv43_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv42_1[1]
            rms_norm217: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv42_1[0]
            lv44 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_11_mixer_qkv_proj_q_weight4, transformer_h_11_mixer_qkv_proj_q_scale4, rms_norm217), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape428 = R.call_tir(cls.reshape, (lv44,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape429 = R.call_tir(cls.reshape1, (reshape428,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv541 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape429), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape430 = R.call_tir(cls.reshape2, (lv541,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape431 = R.call_tir(cls.reshape3, (reshape430,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv45 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_11_mixer_out_proj_q_weight4, transformer_h_11_mixer_out_proj_q_scale4, reshape431), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv44_1 = R.call_tir(cls.fuse_add_norm_decode, (lv45, lv43_1, transformer_h_11_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv45_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv44_1[1]
            rms_norm218: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv44_1[0]
            lv46 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_11_mlp_gate_up_proj_q_weight4, transformer_h_11_mlp_gate_up_proj_q_scale4, rms_norm218), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv11_2 = R.call_tir(cls.fused_split_silu_multiply, (lv46,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv47 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_11_mlp_down_proj_q_weight4, transformer_h_11_mlp_down_proj_q_scale4, lv11_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv46_1 = R.call_tir(cls.fuse_add_norm_decode, (lv47, lv45_1, transformer_h_12_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv47_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv46_1[1]
            rms_norm219: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv46_1[0]
            lv48 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_12_mixer_qkv_proj_q_weight4, transformer_h_12_mixer_qkv_proj_q_scale4, rms_norm219), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape432 = R.call_tir(cls.reshape, (lv48,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape433 = R.call_tir(cls.reshape1, (reshape432,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv546 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape433), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape434 = R.call_tir(cls.reshape2, (lv546,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape435 = R.call_tir(cls.reshape3, (reshape434,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv49 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_12_mixer_out_proj_q_weight4, transformer_h_12_mixer_out_proj_q_scale4, reshape435), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv48_1 = R.call_tir(cls.fuse_add_norm_decode, (lv49, lv47_1, transformer_h_12_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv49_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv48_1[1]
            rms_norm220: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv48_1[0]
            lv50 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_12_mlp_gate_up_proj_q_weight4, transformer_h_12_mlp_gate_up_proj_q_scale4, rms_norm220), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv12_2 = R.call_tir(cls.fused_split_silu_multiply, (lv50,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv51 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_12_mlp_down_proj_q_weight4, transformer_h_12_mlp_down_proj_q_scale4, lv12_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv50_1 = R.call_tir(cls.fuse_add_norm_decode, (lv51, lv49_1, transformer_h_13_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv51_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv50_1[1]
            rms_norm221: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv50_1[0]
            lv52 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_13_mixer_qkv_proj_q_weight4, transformer_h_13_mixer_qkv_proj_q_scale4, rms_norm221), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape436 = R.call_tir(cls.reshape, (lv52,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape437 = R.call_tir(cls.reshape1, (reshape436,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv551 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape437), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape438 = R.call_tir(cls.reshape2, (lv551,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape439 = R.call_tir(cls.reshape3, (reshape438,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv53 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_13_mixer_out_proj_q_weight4, transformer_h_13_mixer_out_proj_q_scale4, reshape439), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv52_1 = R.call_tir(cls.fuse_add_norm_decode, (lv53, lv51_1, transformer_h_13_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv53_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv52_1[1]
            rms_norm222: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv52_1[0]
            lv54 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_13_mlp_gate_up_proj_q_weight4, transformer_h_13_mlp_gate_up_proj_q_scale4, rms_norm222), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv13_2 = R.call_tir(cls.fused_split_silu_multiply, (lv54,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv55 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_13_mlp_down_proj_q_weight4, transformer_h_13_mlp_down_proj_q_scale4, lv13_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv54_1 = R.call_tir(cls.fuse_add_norm_decode, (lv55, lv53_1, transformer_h_14_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv55_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv54_1[1]
            rms_norm223: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv54_1[0]
            lv56 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_14_mixer_qkv_proj_q_weight4, transformer_h_14_mixer_qkv_proj_q_scale4, rms_norm223), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape440 = R.call_tir(cls.reshape, (lv56,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape441 = R.call_tir(cls.reshape1, (reshape440,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv556 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape441), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape442 = R.call_tir(cls.reshape2, (lv556,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape443 = R.call_tir(cls.reshape3, (reshape442,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv57 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_14_mixer_out_proj_q_weight4, transformer_h_14_mixer_out_proj_q_scale4, reshape443), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv56_1 = R.call_tir(cls.fuse_add_norm_decode, (lv57, lv55_1, transformer_h_14_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv57_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv56_1[1]
            rms_norm224: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv56_1[0]
            lv58 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_14_mlp_gate_up_proj_q_weight4, transformer_h_14_mlp_gate_up_proj_q_scale4, rms_norm224), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv14_2 = R.call_tir(cls.fused_split_silu_multiply, (lv58,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv59 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_14_mlp_down_proj_q_weight4, transformer_h_14_mlp_down_proj_q_scale4, lv14_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv58_1 = R.call_tir(cls.fuse_add_norm_decode, (lv59, lv57_1, transformer_h_15_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv59_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv58_1[1]
            rms_norm225: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv58_1[0]
            lv60 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_15_mixer_qkv_proj_q_weight4, transformer_h_15_mixer_qkv_proj_q_scale4, rms_norm225), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape444 = R.call_tir(cls.reshape, (lv60,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape445 = R.call_tir(cls.reshape1, (reshape444,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv561 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape445), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape446 = R.call_tir(cls.reshape2, (lv561,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape447 = R.call_tir(cls.reshape3, (reshape446,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv61 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_15_mixer_out_proj_q_weight4, transformer_h_15_mixer_out_proj_q_scale4, reshape447), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv60_1 = R.call_tir(cls.fuse_add_norm_decode, (lv61, lv59_1, transformer_h_15_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv61_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv60_1[1]
            rms_norm226: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv60_1[0]
            lv62 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_15_mlp_gate_up_proj_q_weight4, transformer_h_15_mlp_gate_up_proj_q_scale4, rms_norm226), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv15_2 = R.call_tir(cls.fused_split_silu_multiply, (lv62,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv63 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_15_mlp_down_proj_q_weight4, transformer_h_15_mlp_down_proj_q_scale4, lv15_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv62_1 = R.call_tir(cls.fuse_add_norm_decode, (lv63, lv61_1, transformer_h_16_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv63_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv62_1[1]
            rms_norm227: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv62_1[0]
            lv64 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_16_mixer_qkv_proj_q_weight4, transformer_h_16_mixer_qkv_proj_q_scale4, rms_norm227), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape448 = R.call_tir(cls.reshape, (lv64,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape449 = R.call_tir(cls.reshape1, (reshape448,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv566 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape449), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape450 = R.call_tir(cls.reshape2, (lv566,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape451 = R.call_tir(cls.reshape3, (reshape450,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv65 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_16_mixer_out_proj_q_weight4, transformer_h_16_mixer_out_proj_q_scale4, reshape451), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv64_1 = R.call_tir(cls.fuse_add_norm_decode, (lv65, lv63_1, transformer_h_16_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv65_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv64_1[1]
            rms_norm228: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv64_1[0]
            lv66 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_16_mlp_gate_up_proj_q_weight4, transformer_h_16_mlp_gate_up_proj_q_scale4, rms_norm228), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv16_2 = R.call_tir(cls.fused_split_silu_multiply, (lv66,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv67 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_16_mlp_down_proj_q_weight4, transformer_h_16_mlp_down_proj_q_scale4, lv16_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv66_1 = R.call_tir(cls.fuse_add_norm_decode, (lv67, lv65_1, transformer_h_17_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv67_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv66_1[1]
            rms_norm229: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv66_1[0]
            lv68 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_17_mixer_qkv_proj_q_weight4, transformer_h_17_mixer_qkv_proj_q_scale4, rms_norm229), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape452 = R.call_tir(cls.reshape, (lv68,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape453 = R.call_tir(cls.reshape1, (reshape452,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv571 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape453), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape454 = R.call_tir(cls.reshape2, (lv571,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape455 = R.call_tir(cls.reshape3, (reshape454,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv69 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_17_mixer_out_proj_q_weight4, transformer_h_17_mixer_out_proj_q_scale4, reshape455), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv68_1 = R.call_tir(cls.fuse_add_norm_decode, (lv69, lv67_1, transformer_h_17_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv69_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv68_1[1]
            rms_norm230: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv68_1[0]
            lv70 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_17_mlp_gate_up_proj_q_weight4, transformer_h_17_mlp_gate_up_proj_q_scale4, rms_norm230), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv17_2 = R.call_tir(cls.fused_split_silu_multiply, (lv70,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv71 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_17_mlp_down_proj_q_weight4, transformer_h_17_mlp_down_proj_q_scale4, lv17_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv70_1 = R.call_tir(cls.fuse_add_norm_decode, (lv71, lv69_1, transformer_h_18_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv71_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv70_1[1]
            rms_norm231: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv70_1[0]
            lv72 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_18_mixer_qkv_proj_q_weight4, transformer_h_18_mixer_qkv_proj_q_scale4, rms_norm231), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape456 = R.call_tir(cls.reshape, (lv72,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape457 = R.call_tir(cls.reshape1, (reshape456,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv576 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape457), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape458 = R.call_tir(cls.reshape2, (lv576,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape459 = R.call_tir(cls.reshape3, (reshape458,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv73 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_18_mixer_out_proj_q_weight4, transformer_h_18_mixer_out_proj_q_scale4, reshape459), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv72_1 = R.call_tir(cls.fuse_add_norm_decode, (lv73, lv71_1, transformer_h_18_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv73_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv72_1[1]
            rms_norm232: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv72_1[0]
            lv74 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_18_mlp_gate_up_proj_q_weight4, transformer_h_18_mlp_gate_up_proj_q_scale4, rms_norm232), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv18_2 = R.call_tir(cls.fused_split_silu_multiply, (lv74,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv75 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_18_mlp_down_proj_q_weight4, transformer_h_18_mlp_down_proj_q_scale4, lv18_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv74_1 = R.call_tir(cls.fuse_add_norm_decode, (lv75, lv73_1, transformer_h_19_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv75_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv74_1[1]
            rms_norm233: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv74_1[0]
            lv76 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_19_mixer_qkv_proj_q_weight4, transformer_h_19_mixer_qkv_proj_q_scale4, rms_norm233), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape460 = R.call_tir(cls.reshape, (lv76,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape461 = R.call_tir(cls.reshape1, (reshape460,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv581 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape461), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape462 = R.call_tir(cls.reshape2, (lv581,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape463 = R.call_tir(cls.reshape3, (reshape462,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv77 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_19_mixer_out_proj_q_weight4, transformer_h_19_mixer_out_proj_q_scale4, reshape463), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv76_1 = R.call_tir(cls.fuse_add_norm_decode, (lv77, lv75_1, transformer_h_19_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv77_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv76_1[1]
            rms_norm234: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv76_1[0]
            lv78 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_19_mlp_gate_up_proj_q_weight4, transformer_h_19_mlp_gate_up_proj_q_scale4, rms_norm234), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv19_2 = R.call_tir(cls.fused_split_silu_multiply, (lv78,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv79 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_19_mlp_down_proj_q_weight4, transformer_h_19_mlp_down_proj_q_scale4, lv19_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv78_1 = R.call_tir(cls.fuse_add_norm_decode, (lv79, lv77_1, transformer_h_20_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv79_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv78_1[1]
            rms_norm235: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv78_1[0]
            lv80 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_20_mixer_qkv_proj_q_weight4, transformer_h_20_mixer_qkv_proj_q_scale4, rms_norm235), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape464 = R.call_tir(cls.reshape, (lv80,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape465 = R.call_tir(cls.reshape1, (reshape464,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv586 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape465), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape466 = R.call_tir(cls.reshape2, (lv586,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape467 = R.call_tir(cls.reshape3, (reshape466,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv81 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_20_mixer_out_proj_q_weight4, transformer_h_20_mixer_out_proj_q_scale4, reshape467), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv80_1 = R.call_tir(cls.fuse_add_norm_decode, (lv81, lv79_1, transformer_h_20_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv81_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv80_1[1]
            rms_norm236: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv80_1[0]
            lv82 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_20_mlp_gate_up_proj_q_weight4, transformer_h_20_mlp_gate_up_proj_q_scale4, rms_norm236), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv20_2 = R.call_tir(cls.fused_split_silu_multiply, (lv82,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv83 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_20_mlp_down_proj_q_weight4, transformer_h_20_mlp_down_proj_q_scale4, lv20_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv82_1 = R.call_tir(cls.fuse_add_norm_decode, (lv83, lv81_1, transformer_h_21_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv83_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv82_1[1]
            rms_norm237: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv82_1[0]
            lv84 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_21_mixer_qkv_proj_q_weight4, transformer_h_21_mixer_qkv_proj_q_scale4, rms_norm237), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape468 = R.call_tir(cls.reshape, (lv84,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape469 = R.call_tir(cls.reshape1, (reshape468,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv591 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape469), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape470 = R.call_tir(cls.reshape2, (lv591,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape471 = R.call_tir(cls.reshape3, (reshape470,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv85 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_21_mixer_out_proj_q_weight4, transformer_h_21_mixer_out_proj_q_scale4, reshape471), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv84_1 = R.call_tir(cls.fuse_add_norm_decode, (lv85, lv83_1, transformer_h_21_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv85_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv84_1[1]
            rms_norm238: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv84_1[0]
            lv86 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_21_mlp_gate_up_proj_q_weight4, transformer_h_21_mlp_gate_up_proj_q_scale4, rms_norm238), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv21_2 = R.call_tir(cls.fused_split_silu_multiply, (lv86,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv87 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_21_mlp_down_proj_q_weight4, transformer_h_21_mlp_down_proj_q_scale4, lv21_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv86_1 = R.call_tir(cls.fuse_add_norm_decode, (lv87, lv85_1, transformer_h_22_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv87_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv86_1[1]
            rms_norm239: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv86_1[0]
            lv88 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_22_mixer_qkv_proj_q_weight4, transformer_h_22_mixer_qkv_proj_q_scale4, rms_norm239), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape472 = R.call_tir(cls.reshape, (lv88,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape473 = R.call_tir(cls.reshape1, (reshape472,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv596 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape473), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape474 = R.call_tir(cls.reshape2, (lv596,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape475 = R.call_tir(cls.reshape3, (reshape474,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv89 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_22_mixer_out_proj_q_weight4, transformer_h_22_mixer_out_proj_q_scale4, reshape475), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv88_1 = R.call_tir(cls.fuse_add_norm_decode, (lv89, lv87_1, transformer_h_22_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv89_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv88_1[1]
            rms_norm240: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv88_1[0]
            lv90 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_22_mlp_gate_up_proj_q_weight4, transformer_h_22_mlp_gate_up_proj_q_scale4, rms_norm240), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv22_2 = R.call_tir(cls.fused_split_silu_multiply, (lv90,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv91 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_22_mlp_down_proj_q_weight4, transformer_h_22_mlp_down_proj_q_scale4, lv22_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv90_1 = R.call_tir(cls.fuse_add_norm_decode, (lv91, lv89_1, transformer_h_23_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv91_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv90_1[1]
            rms_norm241: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv90_1[0]
            lv92 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_23_mixer_qkv_proj_q_weight4, transformer_h_23_mixer_qkv_proj_q_scale4, rms_norm241), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape476 = R.call_tir(cls.reshape, (lv92,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape477 = R.call_tir(cls.reshape1, (reshape476,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv601 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape477), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape478 = R.call_tir(cls.reshape2, (lv601,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape479 = R.call_tir(cls.reshape3, (reshape478,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv93 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_23_mixer_out_proj_q_weight4, transformer_h_23_mixer_out_proj_q_scale4, reshape479), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv92_1 = R.call_tir(cls.fuse_add_norm_decode, (lv93, lv91_1, transformer_h_23_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv93_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv92_1[1]
            rms_norm242: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv92_1[0]
            lv94 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_23_mlp_gate_up_proj_q_weight4, transformer_h_23_mlp_gate_up_proj_q_scale4, rms_norm242), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv23_2 = R.call_tir(cls.fused_split_silu_multiply, (lv94,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv95 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_23_mlp_down_proj_q_weight4, transformer_h_23_mlp_down_proj_q_scale4, lv23_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv94_1 = R.call_tir(cls.fuse_add_norm_decode, (lv95, lv93_1, transformer_h_24_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv95_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv94_1[1]
            rms_norm243: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv94_1[0]
            lv96 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_24_mixer_qkv_proj_q_weight4, transformer_h_24_mixer_qkv_proj_q_scale4, rms_norm243), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape480 = R.call_tir(cls.reshape, (lv96,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape481 = R.call_tir(cls.reshape1, (reshape480,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv606 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape481), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape482 = R.call_tir(cls.reshape2, (lv606,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape483 = R.call_tir(cls.reshape3, (reshape482,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv97 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_24_mixer_out_proj_q_weight4, transformer_h_24_mixer_out_proj_q_scale4, reshape483), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv96_1 = R.call_tir(cls.fuse_add_norm_decode, (lv97, lv95_1, transformer_h_24_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv97_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv96_1[1]
            rms_norm244: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv96_1[0]
            lv98 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_24_mlp_gate_up_proj_q_weight4, transformer_h_24_mlp_gate_up_proj_q_scale4, rms_norm244), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv24_2 = R.call_tir(cls.fused_split_silu_multiply, (lv98,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv99 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_24_mlp_down_proj_q_weight4, transformer_h_24_mlp_down_proj_q_scale4, lv24_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv98_1 = R.call_tir(cls.fuse_add_norm_decode, (lv99, lv97_1, transformer_h_25_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv99_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv98_1[1]
            rms_norm245: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv98_1[0]
            lv100 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_25_mixer_qkv_proj_q_weight4, transformer_h_25_mixer_qkv_proj_q_scale4, rms_norm245), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape484 = R.call_tir(cls.reshape, (lv100,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape485 = R.call_tir(cls.reshape1, (reshape484,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv611 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape485), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape486 = R.call_tir(cls.reshape2, (lv611,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape487 = R.call_tir(cls.reshape3, (reshape486,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv101 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_25_mixer_out_proj_q_weight4, transformer_h_25_mixer_out_proj_q_scale4, reshape487), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv100_1 = R.call_tir(cls.fuse_add_norm_decode, (lv101, lv99_1, transformer_h_25_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv101_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv100_1[1]
            rms_norm246: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv100_1[0]
            lv102 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_25_mlp_gate_up_proj_q_weight4, transformer_h_25_mlp_gate_up_proj_q_scale4, rms_norm246), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv25_2 = R.call_tir(cls.fused_split_silu_multiply, (lv102,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv103 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_25_mlp_down_proj_q_weight4, transformer_h_25_mlp_down_proj_q_scale4, lv25_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv102_1 = R.call_tir(cls.fuse_add_norm_decode, (lv103, lv101_1, transformer_h_26_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv103_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv102_1[1]
            rms_norm247: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv102_1[0]
            lv104 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_26_mixer_qkv_proj_q_weight4, transformer_h_26_mixer_qkv_proj_q_scale4, rms_norm247), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape488 = R.call_tir(cls.reshape, (lv104,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape489 = R.call_tir(cls.reshape1, (reshape488,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv616 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape489), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape490 = R.call_tir(cls.reshape2, (lv616,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape491 = R.call_tir(cls.reshape3, (reshape490,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv105 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_26_mixer_out_proj_q_weight4, transformer_h_26_mixer_out_proj_q_scale4, reshape491), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv104_1 = R.call_tir(cls.fuse_add_norm_decode, (lv105, lv103_1, transformer_h_26_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv105_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv104_1[1]
            rms_norm248: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv104_1[0]
            lv106 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_26_mlp_gate_up_proj_q_weight4, transformer_h_26_mlp_gate_up_proj_q_scale4, rms_norm248), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv26_2 = R.call_tir(cls.fused_split_silu_multiply, (lv106,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv107 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_26_mlp_down_proj_q_weight4, transformer_h_26_mlp_down_proj_q_scale4, lv26_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv106_1 = R.call_tir(cls.fuse_add_norm_decode, (lv107, lv105_1, transformer_h_27_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv107_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv106_1[1]
            rms_norm249: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv106_1[0]
            lv108 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_27_mixer_qkv_proj_q_weight4, transformer_h_27_mixer_qkv_proj_q_scale4, rms_norm249), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape492 = R.call_tir(cls.reshape, (lv108,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape493 = R.call_tir(cls.reshape1, (reshape492,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv621 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape493), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape494 = R.call_tir(cls.reshape2, (lv621,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape495 = R.call_tir(cls.reshape3, (reshape494,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv109 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_27_mixer_out_proj_q_weight4, transformer_h_27_mixer_out_proj_q_scale4, reshape495), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv108_1 = R.call_tir(cls.fuse_add_norm_decode, (lv109, lv107_1, transformer_h_27_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv109_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv108_1[1]
            rms_norm250: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv108_1[0]
            lv110 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_27_mlp_gate_up_proj_q_weight4, transformer_h_27_mlp_gate_up_proj_q_scale4, rms_norm250), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv27_2 = R.call_tir(cls.fused_split_silu_multiply, (lv110,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv111 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_27_mlp_down_proj_q_weight4, transformer_h_27_mlp_down_proj_q_scale4, lv27_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv110_1 = R.call_tir(cls.fuse_add_norm_decode, (lv111, lv109_1, transformer_h_28_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv111_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv110_1[1]
            rms_norm251: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv110_1[0]
            lv112 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_28_mixer_qkv_proj_q_weight4, transformer_h_28_mixer_qkv_proj_q_scale4, rms_norm251), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape496 = R.call_tir(cls.reshape, (lv112,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape497 = R.call_tir(cls.reshape1, (reshape496,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv626 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape497), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape498 = R.call_tir(cls.reshape2, (lv626,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape499 = R.call_tir(cls.reshape3, (reshape498,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv113 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_28_mixer_out_proj_q_weight4, transformer_h_28_mixer_out_proj_q_scale4, reshape499), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv112_1 = R.call_tir(cls.fuse_add_norm_decode, (lv113, lv111_1, transformer_h_28_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv113_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv112_1[1]
            rms_norm252: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv112_1[0]
            lv114 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_28_mlp_gate_up_proj_q_weight4, transformer_h_28_mlp_gate_up_proj_q_scale4, rms_norm252), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv28_2 = R.call_tir(cls.fused_split_silu_multiply, (lv114,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv115 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_28_mlp_down_proj_q_weight4, transformer_h_28_mlp_down_proj_q_scale4, lv28_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv114_1 = R.call_tir(cls.fuse_add_norm_decode, (lv115, lv113_1, transformer_h_29_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv115_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv114_1[1]
            rms_norm253: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv114_1[0]
            lv116 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_29_mixer_qkv_proj_q_weight4, transformer_h_29_mixer_qkv_proj_q_scale4, rms_norm253), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape500 = R.call_tir(cls.reshape, (lv116,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape501 = R.call_tir(cls.reshape1, (reshape500,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv631 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape501), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape502 = R.call_tir(cls.reshape2, (lv631,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape503 = R.call_tir(cls.reshape3, (reshape502,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv117 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_29_mixer_out_proj_q_weight4, transformer_h_29_mixer_out_proj_q_scale4, reshape503), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv116_1 = R.call_tir(cls.fuse_add_norm_decode, (lv117, lv115_1, transformer_h_29_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv117_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv116_1[1]
            rms_norm254: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv116_1[0]
            lv118 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_29_mlp_gate_up_proj_q_weight4, transformer_h_29_mlp_gate_up_proj_q_scale4, rms_norm254), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv29_2 = R.call_tir(cls.fused_split_silu_multiply, (lv118,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv119 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_29_mlp_down_proj_q_weight4, transformer_h_29_mlp_down_proj_q_scale4, lv29_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv118_1 = R.call_tir(cls.fuse_add_norm_decode, (lv119, lv117_1, transformer_h_30_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv119_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv118_1[1]
            rms_norm255: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv118_1[0]
            lv120 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_30_mixer_qkv_proj_q_weight4, transformer_h_30_mixer_qkv_proj_q_scale4, rms_norm255), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape504 = R.call_tir(cls.reshape, (lv120,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape505 = R.call_tir(cls.reshape1, (reshape504,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv636 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape505), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape506 = R.call_tir(cls.reshape2, (lv636,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape507 = R.call_tir(cls.reshape3, (reshape506,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv121 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_30_mixer_out_proj_q_weight4, transformer_h_30_mixer_out_proj_q_scale4, reshape507), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv120_1 = R.call_tir(cls.fuse_add_norm_decode, (lv121, lv119_1, transformer_h_30_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv121_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv120_1[1]
            rms_norm256: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv120_1[0]
            lv122 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_30_mlp_gate_up_proj_q_weight4, transformer_h_30_mlp_gate_up_proj_q_scale4, rms_norm256), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv30_2 = R.call_tir(cls.fused_split_silu_multiply, (lv122,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv123 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_30_mlp_down_proj_q_weight4, transformer_h_30_mlp_down_proj_q_scale4, lv30_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv122_1 = R.call_tir(cls.fuse_add_norm_decode, (lv123, lv121_1, transformer_h_31_ln_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv123_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv122_1[1]
            rms_norm257: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv122_1[0]
            lv124 = R.call_tir(cls.fused_dequantize1_NT_matmul, (transformer_h_31_mixer_qkv_proj_q_weight4, transformer_h_31_mixer_qkv_proj_q_scale4, rms_norm257), out_sinfo=R.Tensor((batch_size, 1, 9216), dtype="float16"))
            reshape508 = R.call_tir(cls.reshape, (lv124,), out_sinfo=R.Tensor((batch_size, 1, 96, 96), dtype="float16"))
            reshape509 = R.call_tir(cls.reshape1, (reshape508,), out_sinfo=R.Tensor((batch_size, 96, 96), dtype="float16"))
            lv641 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape509), out_sinfo=R.Tensor((batch_size, 32, 96), dtype="float16"))
            reshape510 = R.call_tir(cls.reshape2, (lv641,), out_sinfo=R.Tensor((batch_size, 1, 32, 96), dtype="float16"))
            reshape511 = R.call_tir(cls.reshape3, (reshape510,), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv125 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (transformer_h_31_mixer_out_proj_q_weight4, transformer_h_31_mixer_out_proj_q_scale4, reshape511), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv124_1 = R.call_tir(cls.fuse_add_norm_decode, (lv125, lv123_1, transformer_h_31_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            lv125_1: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv124_1[1]
            rms_norm258: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv124_1[0]
            lv126 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (transformer_h_31_mlp_gate_up_proj_q_weight4, transformer_h_31_mlp_gate_up_proj_q_scale4, rms_norm258), out_sinfo=R.Tensor((batch_size, 1, 16384), dtype="float16"))
            lv31_2 = R.call_tir(cls.fused_split_silu_multiply, (lv126,), out_sinfo=R.Tensor((batch_size, 1, 8192), dtype="float16"))
            lv127 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (transformer_h_31_mlp_down_proj_q_weight4, transformer_h_31_mlp_down_proj_q_scale4, lv31_2), out_sinfo=R.Tensor((batch_size, 1, 3072), dtype="float16"))
            lv126_1 = R.call_tir(cls.fuse_add_norm_decode, (lv127, lv125_1, transformer_norm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 3072), dtype="float16"), R.Tensor((batch_size, 1, 3072), dtype="float16")])
            rms_norm259: R.Tensor((batch_size, 1, 3072), dtype="float16") = lv126_1[0]
            lv128 = R.call_tir(cls.fused_dequantize5_fused_NT_matmul4_cast, (lm_head_q_weight4, lm_head_q_scale4, rms_norm259), out_sinfo=R.Tensor((batch_size, 1, vocab_size), dtype="float32"))
            gv4: R.Tuple(R.Tensor((batch_size, 1, vocab_size), dtype="float32"), R.Object) = lv128, paged_kv_cache
            R.output(gv4)
        return gv4

    @R.function
    def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 3072), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32064, 384), dtype="uint32"), R.Tensor((32064, 96), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), 
dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), 
R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 
1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), 
dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), 
R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), 
dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor(("vocab_size", 384), dtype="uint32"), R.Tensor(("vocab_size", 96), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", "vocab_size"), dtype="float32"), R.Object):
        batch_size = T.int64()
        vocab_size = T.int64()
        seq_len = T.int64()
        R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        with R.dataflow():
            transformer_h_0_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[2]
            transformer_h_0_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[3]
            transformer_h_0_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[4]
            transformer_h_0_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[5]
            transformer_h_0_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[6]
            transformer_h_0_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[7]
            transformer_h_0_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[8]
            transformer_h_0_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[9]
            transformer_h_0_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[10]
            transformer_h_0_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[11]
            transformer_h_1_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[12]
            transformer_h_1_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[13]
            transformer_h_1_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[14]
            transformer_h_1_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[15]
            transformer_h_1_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[16]
            transformer_h_1_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[17]
            transformer_h_1_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[18]
            transformer_h_1_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[19]
            transformer_h_1_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[20]
            transformer_h_1_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[21]
            transformer_h_2_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[22]
            transformer_h_2_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[23]
            transformer_h_2_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[24]
            transformer_h_2_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[25]
            transformer_h_2_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[26]
            transformer_h_2_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[27]
            transformer_h_2_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[28]
            transformer_h_2_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[29]
            transformer_h_2_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[30]
            transformer_h_2_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[31]
            transformer_h_3_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[32]
            transformer_h_3_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[33]
            transformer_h_3_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[34]
            transformer_h_3_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[35]
            transformer_h_3_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[36]
            transformer_h_3_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[37]
            transformer_h_3_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[38]
            transformer_h_3_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[39]
            transformer_h_3_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[40]
            transformer_h_3_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[41]
            transformer_h_4_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[42]
            transformer_h_4_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[43]
            transformer_h_4_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[44]
            transformer_h_4_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[45]
            transformer_h_4_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[46]
            transformer_h_4_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[47]
            transformer_h_4_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[48]
            transformer_h_4_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[49]
            transformer_h_4_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[50]
            transformer_h_4_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[51]
            transformer_h_5_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[52]
            transformer_h_5_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[53]
            transformer_h_5_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[54]
            transformer_h_5_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[55]
            transformer_h_5_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[56]
            transformer_h_5_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[57]
            transformer_h_5_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[58]
            transformer_h_5_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[59]
            transformer_h_5_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[60]
            transformer_h_5_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[61]
            transformer_h_6_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[62]
            transformer_h_6_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[63]
            transformer_h_6_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[64]
            transformer_h_6_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[65]
            transformer_h_6_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[66]
            transformer_h_6_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[67]
            transformer_h_6_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[68]
            transformer_h_6_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[69]
            transformer_h_6_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[70]
            transformer_h_6_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[71]
            transformer_h_7_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[72]
            transformer_h_7_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[73]
            transformer_h_7_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[74]
            transformer_h_7_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[75]
            transformer_h_7_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[76]
            transformer_h_7_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[77]
            transformer_h_7_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[78]
            transformer_h_7_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[79]
            transformer_h_7_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[80]
            transformer_h_7_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[81]
            transformer_h_8_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[82]
            transformer_h_8_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[83]
            transformer_h_8_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[84]
            transformer_h_8_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[85]
            transformer_h_8_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[86]
            transformer_h_8_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[87]
            transformer_h_8_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[88]
            transformer_h_8_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[89]
            transformer_h_8_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[90]
            transformer_h_8_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[91]
            transformer_h_9_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[92]
            transformer_h_9_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[93]
            transformer_h_9_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[94]
            transformer_h_9_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[95]
            transformer_h_9_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[96]
            transformer_h_9_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[97]
            transformer_h_9_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[98]
            transformer_h_9_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[99]
            transformer_h_9_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[100]
            transformer_h_9_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[101]
            transformer_h_10_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[102]
            transformer_h_10_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[103]
            transformer_h_10_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[104]
            transformer_h_10_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[105]
            transformer_h_10_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[106]
            transformer_h_10_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[107]
            transformer_h_10_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[108]
            transformer_h_10_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[109]
            transformer_h_10_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[110]
            transformer_h_10_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[111]
            transformer_h_11_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[112]
            transformer_h_11_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[113]
            transformer_h_11_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[114]
            transformer_h_11_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[115]
            transformer_h_11_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[116]
            transformer_h_11_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[117]
            transformer_h_11_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[118]
            transformer_h_11_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[119]
            transformer_h_11_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[120]
            transformer_h_11_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[121]
            transformer_h_12_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[122]
            transformer_h_12_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[123]
            transformer_h_12_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[124]
            transformer_h_12_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[125]
            transformer_h_12_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[126]
            transformer_h_12_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[127]
            transformer_h_12_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[128]
            transformer_h_12_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[129]
            transformer_h_12_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[130]
            transformer_h_12_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[131]
            transformer_h_13_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[132]
            transformer_h_13_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[133]
            transformer_h_13_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[134]
            transformer_h_13_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[135]
            transformer_h_13_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[136]
            transformer_h_13_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[137]
            transformer_h_13_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[138]
            transformer_h_13_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[139]
            transformer_h_13_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[140]
            transformer_h_13_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[141]
            transformer_h_14_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[142]
            transformer_h_14_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[143]
            transformer_h_14_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[144]
            transformer_h_14_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[145]
            transformer_h_14_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[146]
            transformer_h_14_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[147]
            transformer_h_14_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[148]
            transformer_h_14_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[149]
            transformer_h_14_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[150]
            transformer_h_14_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[151]
            transformer_h_15_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[152]
            transformer_h_15_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[153]
            transformer_h_15_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[154]
            transformer_h_15_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[155]
            transformer_h_15_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[156]
            transformer_h_15_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[157]
            transformer_h_15_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[158]
            transformer_h_15_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[159]
            transformer_h_15_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[160]
            transformer_h_15_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[161]
            transformer_h_16_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[162]
            transformer_h_16_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[163]
            transformer_h_16_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[164]
            transformer_h_16_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[165]
            transformer_h_16_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[166]
            transformer_h_16_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[167]
            transformer_h_16_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[168]
            transformer_h_16_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[169]
            transformer_h_16_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[170]
            transformer_h_16_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[171]
            transformer_h_17_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[172]
            transformer_h_17_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[173]
            transformer_h_17_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[174]
            transformer_h_17_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[175]
            transformer_h_17_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[176]
            transformer_h_17_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[177]
            transformer_h_17_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[178]
            transformer_h_17_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[179]
            transformer_h_17_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[180]
            transformer_h_17_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[181]
            transformer_h_18_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[182]
            transformer_h_18_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[183]
            transformer_h_18_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[184]
            transformer_h_18_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[185]
            transformer_h_18_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[186]
            transformer_h_18_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[187]
            transformer_h_18_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[188]
            transformer_h_18_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[189]
            transformer_h_18_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[190]
            transformer_h_18_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[191]
            transformer_h_19_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[192]
            transformer_h_19_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[193]
            transformer_h_19_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[194]
            transformer_h_19_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[195]
            transformer_h_19_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[196]
            transformer_h_19_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[197]
            transformer_h_19_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[198]
            transformer_h_19_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[199]
            transformer_h_19_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[200]
            transformer_h_19_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[201]
            transformer_h_20_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[202]
            transformer_h_20_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[203]
            transformer_h_20_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[204]
            transformer_h_20_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[205]
            transformer_h_20_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[206]
            transformer_h_20_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[207]
            transformer_h_20_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[208]
            transformer_h_20_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[209]
            transformer_h_20_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[210]
            transformer_h_20_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[211]
            transformer_h_21_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[212]
            transformer_h_21_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[213]
            transformer_h_21_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[214]
            transformer_h_21_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[215]
            transformer_h_21_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[216]
            transformer_h_21_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[217]
            transformer_h_21_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[218]
            transformer_h_21_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[219]
            transformer_h_21_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[220]
            transformer_h_21_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[221]
            transformer_h_22_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[222]
            transformer_h_22_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[223]
            transformer_h_22_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[224]
            transformer_h_22_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[225]
            transformer_h_22_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[226]
            transformer_h_22_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[227]
            transformer_h_22_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[228]
            transformer_h_22_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[229]
            transformer_h_22_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[230]
            transformer_h_22_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[231]
            transformer_h_23_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[232]
            transformer_h_23_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[233]
            transformer_h_23_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[234]
            transformer_h_23_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[235]
            transformer_h_23_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[236]
            transformer_h_23_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[237]
            transformer_h_23_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[238]
            transformer_h_23_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[239]
            transformer_h_23_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[240]
            transformer_h_23_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[241]
            transformer_h_24_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[242]
            transformer_h_24_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[243]
            transformer_h_24_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[244]
            transformer_h_24_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[245]
            transformer_h_24_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[246]
            transformer_h_24_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[247]
            transformer_h_24_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[248]
            transformer_h_24_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[249]
            transformer_h_24_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[250]
            transformer_h_24_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[251]
            transformer_h_25_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[252]
            transformer_h_25_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[253]
            transformer_h_25_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[254]
            transformer_h_25_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[255]
            transformer_h_25_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[256]
            transformer_h_25_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[257]
            transformer_h_25_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[258]
            transformer_h_25_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[259]
            transformer_h_25_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[260]
            transformer_h_25_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[261]
            transformer_h_26_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[262]
            transformer_h_26_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[263]
            transformer_h_26_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[264]
            transformer_h_26_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[265]
            transformer_h_26_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[266]
            transformer_h_26_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[267]
            transformer_h_26_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[268]
            transformer_h_26_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[269]
            transformer_h_26_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[270]
            transformer_h_26_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[271]
            transformer_h_27_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[272]
            transformer_h_27_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[273]
            transformer_h_27_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[274]
            transformer_h_27_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[275]
            transformer_h_27_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[276]
            transformer_h_27_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[277]
            transformer_h_27_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[278]
            transformer_h_27_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[279]
            transformer_h_27_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[280]
            transformer_h_27_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[281]
            transformer_h_28_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[282]
            transformer_h_28_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[283]
            transformer_h_28_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[284]
            transformer_h_28_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[285]
            transformer_h_28_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[286]
            transformer_h_28_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[287]
            transformer_h_28_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[288]
            transformer_h_28_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[289]
            transformer_h_28_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[290]
            transformer_h_28_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[291]
            transformer_h_29_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[292]
            transformer_h_29_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[293]
            transformer_h_29_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[294]
            transformer_h_29_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[295]
            transformer_h_29_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[296]
            transformer_h_29_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[297]
            transformer_h_29_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[298]
            transformer_h_29_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[299]
            transformer_h_29_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[300]
            transformer_h_29_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[301]
            transformer_h_30_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[302]
            transformer_h_30_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[303]
            transformer_h_30_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[304]
            transformer_h_30_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[305]
            transformer_h_30_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[306]
            transformer_h_30_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[307]
            transformer_h_30_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[308]
            transformer_h_30_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[309]
            transformer_h_30_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[310]
            transformer_h_30_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[311]
            transformer_h_31_ln_weight3: R.Tensor((3072,), dtype="float16") = packed_params[312]
            transformer_h_31_mixer_qkv_proj_q_weight3: R.Tensor((9216, 384), dtype="uint32") = packed_params[313]
            transformer_h_31_mixer_qkv_proj_q_scale3: R.Tensor((9216, 96), dtype="float16") = packed_params[314]
            transformer_h_31_mixer_out_proj_q_weight3: R.Tensor((3072, 384), dtype="uint32") = packed_params[315]
            transformer_h_31_mixer_out_proj_q_scale3: R.Tensor((3072, 96), dtype="float16") = packed_params[316]
            transformer_h_31_mlp_gate_up_proj_q_weight3: R.Tensor((16384, 384), dtype="uint32") = packed_params[317]
            transformer_h_31_mlp_gate_up_proj_q_scale3: R.Tensor((16384, 96), dtype="float16") = packed_params[318]
            transformer_h_31_mlp_down_proj_q_weight3: R.Tensor((3072, 1024), dtype="uint32") = packed_params[319]
            transformer_h_31_mlp_down_proj_q_scale3: R.Tensor((3072, 256), dtype="float16") = packed_params[320]
            transformer_h_31_post_attention_layernorm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[321]
            transformer_norm_weight3: R.Tensor((3072,), dtype="float16") = packed_params[322]
            lm_head_q_weight3: R.Tensor((vocab_size, 384), dtype="uint32") = packed_params[323]
            lm_head_q_scale3: R.Tensor((vocab_size, 96), dtype="float16") = packed_params[324]
            rms_norm130 = R.call_tir(cls.rms_norm1, (input_embeds, transformer_h_0_ln_weight3), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv129 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_0_mixer_qkv_proj_q_weight3, transformer_h_0_mixer_qkv_proj_q_scale3, rms_norm130), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape256 = R.call_tir(cls.reshape4, (lv129,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape257 = R.call_tir(cls.reshape5, (reshape256,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv325 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape257), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape258 = R.call_tir(cls.reshape6, (lv325,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape259 = R.call_tir(cls.reshape7, (reshape258,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv130 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_0_mixer_out_proj_q_weight3, transformer_h_0_mixer_out_proj_q_scale3, reshape259), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv128 = R.call_tir(cls.fuse_add_norm_prefill, (lv130, input_embeds, transformer_h_0_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv129_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv128[1]
            rms_norm131: R.Tensor((1, seq_len, 3072), dtype="float16") = lv128[0]
            lv131 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_0_mlp_gate_up_proj_q_weight3, transformer_h_0_mlp_gate_up_proj_q_scale3, rms_norm131), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv33 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv131,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv132 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_0_mlp_down_proj_q_weight3, transformer_h_0_mlp_down_proj_q_scale3, lv33), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv130_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv132, lv129_1, transformer_h_1_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv131_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv130_1[1]
            rms_norm132: R.Tensor((1, seq_len, 3072), dtype="float16") = lv130_1[0]
            lv133 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_1_mixer_qkv_proj_q_weight3, transformer_h_1_mixer_qkv_proj_q_scale3, rms_norm132), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape260 = R.call_tir(cls.reshape4, (lv133,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape261 = R.call_tir(cls.reshape5, (reshape260,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv330 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape261), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape262 = R.call_tir(cls.reshape6, (lv330,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape263 = R.call_tir(cls.reshape7, (reshape262,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv134 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_1_mixer_out_proj_q_weight3, transformer_h_1_mixer_out_proj_q_scale3, reshape263), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv132_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv134, lv131_1, transformer_h_1_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv133_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv132_1[1]
            rms_norm133: R.Tensor((1, seq_len, 3072), dtype="float16") = lv132_1[0]
            lv135 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_1_mlp_gate_up_proj_q_weight3, transformer_h_1_mlp_gate_up_proj_q_scale3, rms_norm133), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv34 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv135,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv136 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_1_mlp_down_proj_q_weight3, transformer_h_1_mlp_down_proj_q_scale3, lv34), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv134_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv136, lv133_1, transformer_h_2_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv135_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv134_1[1]
            rms_norm134: R.Tensor((1, seq_len, 3072), dtype="float16") = lv134_1[0]
            lv137 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_2_mixer_qkv_proj_q_weight3, transformer_h_2_mixer_qkv_proj_q_scale3, rms_norm134), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape264 = R.call_tir(cls.reshape4, (lv137,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape265 = R.call_tir(cls.reshape5, (reshape264,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv335 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape265), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape266 = R.call_tir(cls.reshape6, (lv335,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape267 = R.call_tir(cls.reshape7, (reshape266,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv138 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_2_mixer_out_proj_q_weight3, transformer_h_2_mixer_out_proj_q_scale3, reshape267), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv136_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv138, lv135_1, transformer_h_2_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv137_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv136_1[1]
            rms_norm135: R.Tensor((1, seq_len, 3072), dtype="float16") = lv136_1[0]
            lv139 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_2_mlp_gate_up_proj_q_weight3, transformer_h_2_mlp_gate_up_proj_q_scale3, rms_norm135), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv35 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv139,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv140 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_2_mlp_down_proj_q_weight3, transformer_h_2_mlp_down_proj_q_scale3, lv35), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv138_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv140, lv137_1, transformer_h_3_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv139_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv138_1[1]
            rms_norm136: R.Tensor((1, seq_len, 3072), dtype="float16") = lv138_1[0]
            lv141 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_3_mixer_qkv_proj_q_weight3, transformer_h_3_mixer_qkv_proj_q_scale3, rms_norm136), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape268 = R.call_tir(cls.reshape4, (lv141,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape269 = R.call_tir(cls.reshape5, (reshape268,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv340 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape269), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape270 = R.call_tir(cls.reshape6, (lv340,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape271 = R.call_tir(cls.reshape7, (reshape270,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv142 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_3_mixer_out_proj_q_weight3, transformer_h_3_mixer_out_proj_q_scale3, reshape271), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv140_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv142, lv139_1, transformer_h_3_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv141_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv140_1[1]
            rms_norm137: R.Tensor((1, seq_len, 3072), dtype="float16") = lv140_1[0]
            lv143 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_3_mlp_gate_up_proj_q_weight3, transformer_h_3_mlp_gate_up_proj_q_scale3, rms_norm137), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv36 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv143,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv144 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_3_mlp_down_proj_q_weight3, transformer_h_3_mlp_down_proj_q_scale3, lv36), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv142_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv144, lv141_1, transformer_h_4_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv143_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv142_1[1]
            rms_norm138: R.Tensor((1, seq_len, 3072), dtype="float16") = lv142_1[0]
            lv145 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_4_mixer_qkv_proj_q_weight3, transformer_h_4_mixer_qkv_proj_q_scale3, rms_norm138), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape272 = R.call_tir(cls.reshape4, (lv145,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape273 = R.call_tir(cls.reshape5, (reshape272,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv345 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape273), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape274 = R.call_tir(cls.reshape6, (lv345,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape275 = R.call_tir(cls.reshape7, (reshape274,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv146 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_4_mixer_out_proj_q_weight3, transformer_h_4_mixer_out_proj_q_scale3, reshape275), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv144_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv146, lv143_1, transformer_h_4_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv145_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv144_1[1]
            rms_norm139: R.Tensor((1, seq_len, 3072), dtype="float16") = lv144_1[0]
            lv147 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_4_mlp_gate_up_proj_q_weight3, transformer_h_4_mlp_gate_up_proj_q_scale3, rms_norm139), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv37 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv147,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv148 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_4_mlp_down_proj_q_weight3, transformer_h_4_mlp_down_proj_q_scale3, lv37), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv146_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv148, lv145_1, transformer_h_5_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv147_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv146_1[1]
            rms_norm140: R.Tensor((1, seq_len, 3072), dtype="float16") = lv146_1[0]
            lv149 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_5_mixer_qkv_proj_q_weight3, transformer_h_5_mixer_qkv_proj_q_scale3, rms_norm140), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape276 = R.call_tir(cls.reshape4, (lv149,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape277 = R.call_tir(cls.reshape5, (reshape276,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv350 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape277), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape278 = R.call_tir(cls.reshape6, (lv350,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape279 = R.call_tir(cls.reshape7, (reshape278,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv150 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_5_mixer_out_proj_q_weight3, transformer_h_5_mixer_out_proj_q_scale3, reshape279), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv148_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv150, lv147_1, transformer_h_5_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv149_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv148_1[1]
            rms_norm141: R.Tensor((1, seq_len, 3072), dtype="float16") = lv148_1[0]
            lv151 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_5_mlp_gate_up_proj_q_weight3, transformer_h_5_mlp_gate_up_proj_q_scale3, rms_norm141), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv38 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv151,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv152 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_5_mlp_down_proj_q_weight3, transformer_h_5_mlp_down_proj_q_scale3, lv38), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv150_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv152, lv149_1, transformer_h_6_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv151_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv150_1[1]
            rms_norm142: R.Tensor((1, seq_len, 3072), dtype="float16") = lv150_1[0]
            lv153 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_6_mixer_qkv_proj_q_weight3, transformer_h_6_mixer_qkv_proj_q_scale3, rms_norm142), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape280 = R.call_tir(cls.reshape4, (lv153,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape281 = R.call_tir(cls.reshape5, (reshape280,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv355 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape281), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape282 = R.call_tir(cls.reshape6, (lv355,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape283 = R.call_tir(cls.reshape7, (reshape282,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv154 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_6_mixer_out_proj_q_weight3, transformer_h_6_mixer_out_proj_q_scale3, reshape283), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv152_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv154, lv151_1, transformer_h_6_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv153_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv152_1[1]
            rms_norm143: R.Tensor((1, seq_len, 3072), dtype="float16") = lv152_1[0]
            lv155 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_6_mlp_gate_up_proj_q_weight3, transformer_h_6_mlp_gate_up_proj_q_scale3, rms_norm143), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv39 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv155,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv156 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_6_mlp_down_proj_q_weight3, transformer_h_6_mlp_down_proj_q_scale3, lv39), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv154_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv156, lv153_1, transformer_h_7_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv155_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv154_1[1]
            rms_norm144: R.Tensor((1, seq_len, 3072), dtype="float16") = lv154_1[0]
            lv157 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_7_mixer_qkv_proj_q_weight3, transformer_h_7_mixer_qkv_proj_q_scale3, rms_norm144), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape284 = R.call_tir(cls.reshape4, (lv157,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape285 = R.call_tir(cls.reshape5, (reshape284,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv360 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape285), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape286 = R.call_tir(cls.reshape6, (lv360,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape287 = R.call_tir(cls.reshape7, (reshape286,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv158 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_7_mixer_out_proj_q_weight3, transformer_h_7_mixer_out_proj_q_scale3, reshape287), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv156_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv158, lv155_1, transformer_h_7_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv157_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv156_1[1]
            rms_norm145: R.Tensor((1, seq_len, 3072), dtype="float16") = lv156_1[0]
            lv159 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_7_mlp_gate_up_proj_q_weight3, transformer_h_7_mlp_gate_up_proj_q_scale3, rms_norm145), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv40 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv159,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv160 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_7_mlp_down_proj_q_weight3, transformer_h_7_mlp_down_proj_q_scale3, lv40), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv158_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv160, lv157_1, transformer_h_8_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv159_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv158_1[1]
            rms_norm146: R.Tensor((1, seq_len, 3072), dtype="float16") = lv158_1[0]
            lv161 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_8_mixer_qkv_proj_q_weight3, transformer_h_8_mixer_qkv_proj_q_scale3, rms_norm146), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape288 = R.call_tir(cls.reshape4, (lv161,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape289 = R.call_tir(cls.reshape5, (reshape288,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv365 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape289), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape290 = R.call_tir(cls.reshape6, (lv365,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape291 = R.call_tir(cls.reshape7, (reshape290,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv162 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_8_mixer_out_proj_q_weight3, transformer_h_8_mixer_out_proj_q_scale3, reshape291), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv160_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv162, lv159_1, transformer_h_8_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv161_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv160_1[1]
            rms_norm147: R.Tensor((1, seq_len, 3072), dtype="float16") = lv160_1[0]
            lv163 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_8_mlp_gate_up_proj_q_weight3, transformer_h_8_mlp_gate_up_proj_q_scale3, rms_norm147), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv41 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv163,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv164 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_8_mlp_down_proj_q_weight3, transformer_h_8_mlp_down_proj_q_scale3, lv41), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv162_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv164, lv161_1, transformer_h_9_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv163_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv162_1[1]
            rms_norm148: R.Tensor((1, seq_len, 3072), dtype="float16") = lv162_1[0]
            lv165 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_9_mixer_qkv_proj_q_weight3, transformer_h_9_mixer_qkv_proj_q_scale3, rms_norm148), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape292 = R.call_tir(cls.reshape4, (lv165,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape293 = R.call_tir(cls.reshape5, (reshape292,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv370 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape293), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape294 = R.call_tir(cls.reshape6, (lv370,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape295 = R.call_tir(cls.reshape7, (reshape294,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv166 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_9_mixer_out_proj_q_weight3, transformer_h_9_mixer_out_proj_q_scale3, reshape295), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv164_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv166, lv163_1, transformer_h_9_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv165_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv164_1[1]
            rms_norm149: R.Tensor((1, seq_len, 3072), dtype="float16") = lv164_1[0]
            lv167 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_9_mlp_gate_up_proj_q_weight3, transformer_h_9_mlp_gate_up_proj_q_scale3, rms_norm149), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv42 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv167,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv168 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_9_mlp_down_proj_q_weight3, transformer_h_9_mlp_down_proj_q_scale3, lv42), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv166_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv168, lv165_1, transformer_h_10_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv167_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv166_1[1]
            rms_norm150: R.Tensor((1, seq_len, 3072), dtype="float16") = lv166_1[0]
            lv169 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_10_mixer_qkv_proj_q_weight3, transformer_h_10_mixer_qkv_proj_q_scale3, rms_norm150), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape296 = R.call_tir(cls.reshape4, (lv169,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape297 = R.call_tir(cls.reshape5, (reshape296,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv375 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape297), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape298 = R.call_tir(cls.reshape6, (lv375,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape299 = R.call_tir(cls.reshape7, (reshape298,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv170 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_10_mixer_out_proj_q_weight3, transformer_h_10_mixer_out_proj_q_scale3, reshape299), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv168_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv170, lv167_1, transformer_h_10_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv169_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv168_1[1]
            rms_norm151: R.Tensor((1, seq_len, 3072), dtype="float16") = lv168_1[0]
            lv171 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_10_mlp_gate_up_proj_q_weight3, transformer_h_10_mlp_gate_up_proj_q_scale3, rms_norm151), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv43 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv171,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv172 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_10_mlp_down_proj_q_weight3, transformer_h_10_mlp_down_proj_q_scale3, lv43), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv170_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv172, lv169_1, transformer_h_11_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv171_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv170_1[1]
            rms_norm152: R.Tensor((1, seq_len, 3072), dtype="float16") = lv170_1[0]
            lv173 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_11_mixer_qkv_proj_q_weight3, transformer_h_11_mixer_qkv_proj_q_scale3, rms_norm152), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape300 = R.call_tir(cls.reshape4, (lv173,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape301 = R.call_tir(cls.reshape5, (reshape300,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv380 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape301), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape302 = R.call_tir(cls.reshape6, (lv380,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape303 = R.call_tir(cls.reshape7, (reshape302,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv174 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_11_mixer_out_proj_q_weight3, transformer_h_11_mixer_out_proj_q_scale3, reshape303), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv172_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv174, lv171_1, transformer_h_11_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv173_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv172_1[1]
            rms_norm153: R.Tensor((1, seq_len, 3072), dtype="float16") = lv172_1[0]
            lv175 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_11_mlp_gate_up_proj_q_weight3, transformer_h_11_mlp_gate_up_proj_q_scale3, rms_norm153), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv44 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv175,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv176 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_11_mlp_down_proj_q_weight3, transformer_h_11_mlp_down_proj_q_scale3, lv44), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv174_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv176, lv173_1, transformer_h_12_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv175_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv174_1[1]
            rms_norm154: R.Tensor((1, seq_len, 3072), dtype="float16") = lv174_1[0]
            lv177 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_12_mixer_qkv_proj_q_weight3, transformer_h_12_mixer_qkv_proj_q_scale3, rms_norm154), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape304 = R.call_tir(cls.reshape4, (lv177,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape305 = R.call_tir(cls.reshape5, (reshape304,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv385 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape305), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape306 = R.call_tir(cls.reshape6, (lv385,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape307 = R.call_tir(cls.reshape7, (reshape306,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv178 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_12_mixer_out_proj_q_weight3, transformer_h_12_mixer_out_proj_q_scale3, reshape307), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv176_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv178, lv175_1, transformer_h_12_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv177_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv176_1[1]
            rms_norm155: R.Tensor((1, seq_len, 3072), dtype="float16") = lv176_1[0]
            lv179 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_12_mlp_gate_up_proj_q_weight3, transformer_h_12_mlp_gate_up_proj_q_scale3, rms_norm155), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv45 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv179,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv180 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_12_mlp_down_proj_q_weight3, transformer_h_12_mlp_down_proj_q_scale3, lv45), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv178_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv180, lv177_1, transformer_h_13_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv179_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv178_1[1]
            rms_norm156: R.Tensor((1, seq_len, 3072), dtype="float16") = lv178_1[0]
            lv181 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_13_mixer_qkv_proj_q_weight3, transformer_h_13_mixer_qkv_proj_q_scale3, rms_norm156), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape308 = R.call_tir(cls.reshape4, (lv181,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape309 = R.call_tir(cls.reshape5, (reshape308,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv390 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape309), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape310 = R.call_tir(cls.reshape6, (lv390,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape311 = R.call_tir(cls.reshape7, (reshape310,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv182 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_13_mixer_out_proj_q_weight3, transformer_h_13_mixer_out_proj_q_scale3, reshape311), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv180_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv182, lv179_1, transformer_h_13_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv181_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv180_1[1]
            rms_norm157: R.Tensor((1, seq_len, 3072), dtype="float16") = lv180_1[0]
            lv183 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_13_mlp_gate_up_proj_q_weight3, transformer_h_13_mlp_gate_up_proj_q_scale3, rms_norm157), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv46 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv183,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv184 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_13_mlp_down_proj_q_weight3, transformer_h_13_mlp_down_proj_q_scale3, lv46), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv182_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv184, lv181_1, transformer_h_14_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv183_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv182_1[1]
            rms_norm158: R.Tensor((1, seq_len, 3072), dtype="float16") = lv182_1[0]
            lv185 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_14_mixer_qkv_proj_q_weight3, transformer_h_14_mixer_qkv_proj_q_scale3, rms_norm158), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape312 = R.call_tir(cls.reshape4, (lv185,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape313 = R.call_tir(cls.reshape5, (reshape312,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv395 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape313), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape314 = R.call_tir(cls.reshape6, (lv395,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape315 = R.call_tir(cls.reshape7, (reshape314,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv186 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_14_mixer_out_proj_q_weight3, transformer_h_14_mixer_out_proj_q_scale3, reshape315), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv184_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv186, lv183_1, transformer_h_14_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv185_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv184_1[1]
            rms_norm159: R.Tensor((1, seq_len, 3072), dtype="float16") = lv184_1[0]
            lv187 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_14_mlp_gate_up_proj_q_weight3, transformer_h_14_mlp_gate_up_proj_q_scale3, rms_norm159), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv47 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv187,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv188 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_14_mlp_down_proj_q_weight3, transformer_h_14_mlp_down_proj_q_scale3, lv47), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv186_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv188, lv185_1, transformer_h_15_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv187_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv186_1[1]
            rms_norm160: R.Tensor((1, seq_len, 3072), dtype="float16") = lv186_1[0]
            lv189 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_15_mixer_qkv_proj_q_weight3, transformer_h_15_mixer_qkv_proj_q_scale3, rms_norm160), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape316 = R.call_tir(cls.reshape4, (lv189,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape317 = R.call_tir(cls.reshape5, (reshape316,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv400 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape317), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape318 = R.call_tir(cls.reshape6, (lv400,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape319 = R.call_tir(cls.reshape7, (reshape318,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv190 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_15_mixer_out_proj_q_weight3, transformer_h_15_mixer_out_proj_q_scale3, reshape319), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv188_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv190, lv187_1, transformer_h_15_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv189_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv188_1[1]
            rms_norm161: R.Tensor((1, seq_len, 3072), dtype="float16") = lv188_1[0]
            lv191 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_15_mlp_gate_up_proj_q_weight3, transformer_h_15_mlp_gate_up_proj_q_scale3, rms_norm161), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv48 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv191,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv192 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_15_mlp_down_proj_q_weight3, transformer_h_15_mlp_down_proj_q_scale3, lv48), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv190_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv192, lv189_1, transformer_h_16_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv191_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv190_1[1]
            rms_norm162: R.Tensor((1, seq_len, 3072), dtype="float16") = lv190_1[0]
            lv193 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_16_mixer_qkv_proj_q_weight3, transformer_h_16_mixer_qkv_proj_q_scale3, rms_norm162), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape320 = R.call_tir(cls.reshape4, (lv193,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape321 = R.call_tir(cls.reshape5, (reshape320,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv405 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape321), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape322 = R.call_tir(cls.reshape6, (lv405,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape323 = R.call_tir(cls.reshape7, (reshape322,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv194 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_16_mixer_out_proj_q_weight3, transformer_h_16_mixer_out_proj_q_scale3, reshape323), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv192_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv194, lv191_1, transformer_h_16_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv193_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv192_1[1]
            rms_norm163: R.Tensor((1, seq_len, 3072), dtype="float16") = lv192_1[0]
            lv195 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_16_mlp_gate_up_proj_q_weight3, transformer_h_16_mlp_gate_up_proj_q_scale3, rms_norm163), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv49 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv195,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv196 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_16_mlp_down_proj_q_weight3, transformer_h_16_mlp_down_proj_q_scale3, lv49), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv194_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv196, lv193_1, transformer_h_17_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv195_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv194_1[1]
            rms_norm164: R.Tensor((1, seq_len, 3072), dtype="float16") = lv194_1[0]
            lv197 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_17_mixer_qkv_proj_q_weight3, transformer_h_17_mixer_qkv_proj_q_scale3, rms_norm164), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape324 = R.call_tir(cls.reshape4, (lv197,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape325 = R.call_tir(cls.reshape5, (reshape324,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv410 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape325), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape326 = R.call_tir(cls.reshape6, (lv410,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape327 = R.call_tir(cls.reshape7, (reshape326,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv198 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_17_mixer_out_proj_q_weight3, transformer_h_17_mixer_out_proj_q_scale3, reshape327), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv196_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv198, lv195_1, transformer_h_17_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv197_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv196_1[1]
            rms_norm165: R.Tensor((1, seq_len, 3072), dtype="float16") = lv196_1[0]
            lv199 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_17_mlp_gate_up_proj_q_weight3, transformer_h_17_mlp_gate_up_proj_q_scale3, rms_norm165), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv50 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv199,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv200 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_17_mlp_down_proj_q_weight3, transformer_h_17_mlp_down_proj_q_scale3, lv50), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv198_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv200, lv197_1, transformer_h_18_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv199_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv198_1[1]
            rms_norm166: R.Tensor((1, seq_len, 3072), dtype="float16") = lv198_1[0]
            lv201 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_18_mixer_qkv_proj_q_weight3, transformer_h_18_mixer_qkv_proj_q_scale3, rms_norm166), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape328 = R.call_tir(cls.reshape4, (lv201,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape329 = R.call_tir(cls.reshape5, (reshape328,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv415 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape329), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape330 = R.call_tir(cls.reshape6, (lv415,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape331 = R.call_tir(cls.reshape7, (reshape330,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv202 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_18_mixer_out_proj_q_weight3, transformer_h_18_mixer_out_proj_q_scale3, reshape331), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv200_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv202, lv199_1, transformer_h_18_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv201_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv200_1[1]
            rms_norm167: R.Tensor((1, seq_len, 3072), dtype="float16") = lv200_1[0]
            lv203 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_18_mlp_gate_up_proj_q_weight3, transformer_h_18_mlp_gate_up_proj_q_scale3, rms_norm167), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv51 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv203,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv204 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_18_mlp_down_proj_q_weight3, transformer_h_18_mlp_down_proj_q_scale3, lv51), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv202_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv204, lv201_1, transformer_h_19_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv203_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv202_1[1]
            rms_norm168: R.Tensor((1, seq_len, 3072), dtype="float16") = lv202_1[0]
            lv205 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_19_mixer_qkv_proj_q_weight3, transformer_h_19_mixer_qkv_proj_q_scale3, rms_norm168), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape332 = R.call_tir(cls.reshape4, (lv205,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape333 = R.call_tir(cls.reshape5, (reshape332,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv420 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape333), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape334 = R.call_tir(cls.reshape6, (lv420,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape335 = R.call_tir(cls.reshape7, (reshape334,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv206 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_19_mixer_out_proj_q_weight3, transformer_h_19_mixer_out_proj_q_scale3, reshape335), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv204_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv206, lv203_1, transformer_h_19_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv205_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv204_1[1]
            rms_norm169: R.Tensor((1, seq_len, 3072), dtype="float16") = lv204_1[0]
            lv207 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_19_mlp_gate_up_proj_q_weight3, transformer_h_19_mlp_gate_up_proj_q_scale3, rms_norm169), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv52 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv207,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv208 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_19_mlp_down_proj_q_weight3, transformer_h_19_mlp_down_proj_q_scale3, lv52), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv206_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv208, lv205_1, transformer_h_20_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv207_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv206_1[1]
            rms_norm170: R.Tensor((1, seq_len, 3072), dtype="float16") = lv206_1[0]
            lv209 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_20_mixer_qkv_proj_q_weight3, transformer_h_20_mixer_qkv_proj_q_scale3, rms_norm170), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape336 = R.call_tir(cls.reshape4, (lv209,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape337 = R.call_tir(cls.reshape5, (reshape336,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv425 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape337), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape338 = R.call_tir(cls.reshape6, (lv425,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape339 = R.call_tir(cls.reshape7, (reshape338,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv210 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_20_mixer_out_proj_q_weight3, transformer_h_20_mixer_out_proj_q_scale3, reshape339), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv208_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv210, lv207_1, transformer_h_20_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv209_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv208_1[1]
            rms_norm171: R.Tensor((1, seq_len, 3072), dtype="float16") = lv208_1[0]
            lv211 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_20_mlp_gate_up_proj_q_weight3, transformer_h_20_mlp_gate_up_proj_q_scale3, rms_norm171), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv53 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv211,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv212 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_20_mlp_down_proj_q_weight3, transformer_h_20_mlp_down_proj_q_scale3, lv53), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv210_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv212, lv209_1, transformer_h_21_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv211_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv210_1[1]
            rms_norm172: R.Tensor((1, seq_len, 3072), dtype="float16") = lv210_1[0]
            lv213 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_21_mixer_qkv_proj_q_weight3, transformer_h_21_mixer_qkv_proj_q_scale3, rms_norm172), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape340 = R.call_tir(cls.reshape4, (lv213,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape341 = R.call_tir(cls.reshape5, (reshape340,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv430 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape341), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape342 = R.call_tir(cls.reshape6, (lv430,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape343 = R.call_tir(cls.reshape7, (reshape342,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv214 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_21_mixer_out_proj_q_weight3, transformer_h_21_mixer_out_proj_q_scale3, reshape343), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv212_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv214, lv211_1, transformer_h_21_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv213_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv212_1[1]
            rms_norm173: R.Tensor((1, seq_len, 3072), dtype="float16") = lv212_1[0]
            lv215 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_21_mlp_gate_up_proj_q_weight3, transformer_h_21_mlp_gate_up_proj_q_scale3, rms_norm173), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv54 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv215,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv216 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_21_mlp_down_proj_q_weight3, transformer_h_21_mlp_down_proj_q_scale3, lv54), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv214_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv216, lv213_1, transformer_h_22_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv215_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv214_1[1]
            rms_norm174: R.Tensor((1, seq_len, 3072), dtype="float16") = lv214_1[0]
            lv217 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_22_mixer_qkv_proj_q_weight3, transformer_h_22_mixer_qkv_proj_q_scale3, rms_norm174), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape344 = R.call_tir(cls.reshape4, (lv217,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape345 = R.call_tir(cls.reshape5, (reshape344,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv435 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape345), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape346 = R.call_tir(cls.reshape6, (lv435,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape347 = R.call_tir(cls.reshape7, (reshape346,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv218 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_22_mixer_out_proj_q_weight3, transformer_h_22_mixer_out_proj_q_scale3, reshape347), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv216_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv218, lv215_1, transformer_h_22_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv217_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv216_1[1]
            rms_norm175: R.Tensor((1, seq_len, 3072), dtype="float16") = lv216_1[0]
            lv219 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_22_mlp_gate_up_proj_q_weight3, transformer_h_22_mlp_gate_up_proj_q_scale3, rms_norm175), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv55 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv219,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv220 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_22_mlp_down_proj_q_weight3, transformer_h_22_mlp_down_proj_q_scale3, lv55), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv218_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv220, lv217_1, transformer_h_23_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv219_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv218_1[1]
            rms_norm176: R.Tensor((1, seq_len, 3072), dtype="float16") = lv218_1[0]
            lv221 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_23_mixer_qkv_proj_q_weight3, transformer_h_23_mixer_qkv_proj_q_scale3, rms_norm176), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape348 = R.call_tir(cls.reshape4, (lv221,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape349 = R.call_tir(cls.reshape5, (reshape348,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv440 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape349), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape350 = R.call_tir(cls.reshape6, (lv440,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape351 = R.call_tir(cls.reshape7, (reshape350,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv222 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_23_mixer_out_proj_q_weight3, transformer_h_23_mixer_out_proj_q_scale3, reshape351), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv220_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv222, lv219_1, transformer_h_23_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv221_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv220_1[1]
            rms_norm177: R.Tensor((1, seq_len, 3072), dtype="float16") = lv220_1[0]
            lv223 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_23_mlp_gate_up_proj_q_weight3, transformer_h_23_mlp_gate_up_proj_q_scale3, rms_norm177), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv56 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv223,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv224 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_23_mlp_down_proj_q_weight3, transformer_h_23_mlp_down_proj_q_scale3, lv56), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv222_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv224, lv221_1, transformer_h_24_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv223_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv222_1[1]
            rms_norm178: R.Tensor((1, seq_len, 3072), dtype="float16") = lv222_1[0]
            lv225 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_24_mixer_qkv_proj_q_weight3, transformer_h_24_mixer_qkv_proj_q_scale3, rms_norm178), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape352 = R.call_tir(cls.reshape4, (lv225,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape353 = R.call_tir(cls.reshape5, (reshape352,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv445 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape353), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape354 = R.call_tir(cls.reshape6, (lv445,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape355 = R.call_tir(cls.reshape7, (reshape354,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv226 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_24_mixer_out_proj_q_weight3, transformer_h_24_mixer_out_proj_q_scale3, reshape355), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv224_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv226, lv223_1, transformer_h_24_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv225_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv224_1[1]
            rms_norm179: R.Tensor((1, seq_len, 3072), dtype="float16") = lv224_1[0]
            lv227 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_24_mlp_gate_up_proj_q_weight3, transformer_h_24_mlp_gate_up_proj_q_scale3, rms_norm179), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv57 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv227,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv228 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_24_mlp_down_proj_q_weight3, transformer_h_24_mlp_down_proj_q_scale3, lv57), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv226_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv228, lv225_1, transformer_h_25_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv227_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv226_1[1]
            rms_norm180: R.Tensor((1, seq_len, 3072), dtype="float16") = lv226_1[0]
            lv229 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_25_mixer_qkv_proj_q_weight3, transformer_h_25_mixer_qkv_proj_q_scale3, rms_norm180), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape356 = R.call_tir(cls.reshape4, (lv229,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape357 = R.call_tir(cls.reshape5, (reshape356,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv450 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape357), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape358 = R.call_tir(cls.reshape6, (lv450,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape359 = R.call_tir(cls.reshape7, (reshape358,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv230 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_25_mixer_out_proj_q_weight3, transformer_h_25_mixer_out_proj_q_scale3, reshape359), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv228_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv230, lv227_1, transformer_h_25_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv229_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv228_1[1]
            rms_norm181: R.Tensor((1, seq_len, 3072), dtype="float16") = lv228_1[0]
            lv231 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_25_mlp_gate_up_proj_q_weight3, transformer_h_25_mlp_gate_up_proj_q_scale3, rms_norm181), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv58 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv231,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv232 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_25_mlp_down_proj_q_weight3, transformer_h_25_mlp_down_proj_q_scale3, lv58), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv230_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv232, lv229_1, transformer_h_26_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv231_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv230_1[1]
            rms_norm182: R.Tensor((1, seq_len, 3072), dtype="float16") = lv230_1[0]
            lv233 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_26_mixer_qkv_proj_q_weight3, transformer_h_26_mixer_qkv_proj_q_scale3, rms_norm182), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape360 = R.call_tir(cls.reshape4, (lv233,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape361 = R.call_tir(cls.reshape5, (reshape360,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv455 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape361), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape362 = R.call_tir(cls.reshape6, (lv455,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape363 = R.call_tir(cls.reshape7, (reshape362,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv234 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_26_mixer_out_proj_q_weight3, transformer_h_26_mixer_out_proj_q_scale3, reshape363), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv232_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv234, lv231_1, transformer_h_26_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv233_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv232_1[1]
            rms_norm183: R.Tensor((1, seq_len, 3072), dtype="float16") = lv232_1[0]
            lv235 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_26_mlp_gate_up_proj_q_weight3, transformer_h_26_mlp_gate_up_proj_q_scale3, rms_norm183), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv59 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv235,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv236 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_26_mlp_down_proj_q_weight3, transformer_h_26_mlp_down_proj_q_scale3, lv59), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv234_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv236, lv233_1, transformer_h_27_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv235_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv234_1[1]
            rms_norm184: R.Tensor((1, seq_len, 3072), dtype="float16") = lv234_1[0]
            lv237 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_27_mixer_qkv_proj_q_weight3, transformer_h_27_mixer_qkv_proj_q_scale3, rms_norm184), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape364 = R.call_tir(cls.reshape4, (lv237,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape365 = R.call_tir(cls.reshape5, (reshape364,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv460 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape365), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape366 = R.call_tir(cls.reshape6, (lv460,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape367 = R.call_tir(cls.reshape7, (reshape366,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv238 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_27_mixer_out_proj_q_weight3, transformer_h_27_mixer_out_proj_q_scale3, reshape367), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv236_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv238, lv235_1, transformer_h_27_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv237_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv236_1[1]
            rms_norm185: R.Tensor((1, seq_len, 3072), dtype="float16") = lv236_1[0]
            lv239 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_27_mlp_gate_up_proj_q_weight3, transformer_h_27_mlp_gate_up_proj_q_scale3, rms_norm185), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv60 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv239,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv240 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_27_mlp_down_proj_q_weight3, transformer_h_27_mlp_down_proj_q_scale3, lv60), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv238_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv240, lv237_1, transformer_h_28_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv239_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv238_1[1]
            rms_norm186: R.Tensor((1, seq_len, 3072), dtype="float16") = lv238_1[0]
            lv241 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_28_mixer_qkv_proj_q_weight3, transformer_h_28_mixer_qkv_proj_q_scale3, rms_norm186), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape368 = R.call_tir(cls.reshape4, (lv241,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape369 = R.call_tir(cls.reshape5, (reshape368,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv465 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape369), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape370 = R.call_tir(cls.reshape6, (lv465,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape371 = R.call_tir(cls.reshape7, (reshape370,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv242 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_28_mixer_out_proj_q_weight3, transformer_h_28_mixer_out_proj_q_scale3, reshape371), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv240_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv242, lv239_1, transformer_h_28_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv241_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv240_1[1]
            rms_norm187: R.Tensor((1, seq_len, 3072), dtype="float16") = lv240_1[0]
            lv243 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_28_mlp_gate_up_proj_q_weight3, transformer_h_28_mlp_gate_up_proj_q_scale3, rms_norm187), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv61 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv243,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv244 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_28_mlp_down_proj_q_weight3, transformer_h_28_mlp_down_proj_q_scale3, lv61), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv242_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv244, lv241_1, transformer_h_29_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv243_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv242_1[1]
            rms_norm188: R.Tensor((1, seq_len, 3072), dtype="float16") = lv242_1[0]
            lv245 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_29_mixer_qkv_proj_q_weight3, transformer_h_29_mixer_qkv_proj_q_scale3, rms_norm188), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape372 = R.call_tir(cls.reshape4, (lv245,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape373 = R.call_tir(cls.reshape5, (reshape372,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv470 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape373), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape374 = R.call_tir(cls.reshape6, (lv470,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape375 = R.call_tir(cls.reshape7, (reshape374,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv246 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_29_mixer_out_proj_q_weight3, transformer_h_29_mixer_out_proj_q_scale3, reshape375), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv244_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv246, lv243_1, transformer_h_29_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv245_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv244_1[1]
            rms_norm189: R.Tensor((1, seq_len, 3072), dtype="float16") = lv244_1[0]
            lv247 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_29_mlp_gate_up_proj_q_weight3, transformer_h_29_mlp_gate_up_proj_q_scale3, rms_norm189), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv62 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv247,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv248 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_29_mlp_down_proj_q_weight3, transformer_h_29_mlp_down_proj_q_scale3, lv62), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv246_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv248, lv245_1, transformer_h_30_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv247_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv246_1[1]
            rms_norm190: R.Tensor((1, seq_len, 3072), dtype="float16") = lv246_1[0]
            lv249 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_30_mixer_qkv_proj_q_weight3, transformer_h_30_mixer_qkv_proj_q_scale3, rms_norm190), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape376 = R.call_tir(cls.reshape4, (lv249,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape377 = R.call_tir(cls.reshape5, (reshape376,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv475 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape377), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape378 = R.call_tir(cls.reshape6, (lv475,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape379 = R.call_tir(cls.reshape7, (reshape378,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv250 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_30_mixer_out_proj_q_weight3, transformer_h_30_mixer_out_proj_q_scale3, reshape379), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv248_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv250, lv247_1, transformer_h_30_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv249_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv248_1[1]
            rms_norm191: R.Tensor((1, seq_len, 3072), dtype="float16") = lv248_1[0]
            lv251 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_30_mlp_gate_up_proj_q_weight3, transformer_h_30_mlp_gate_up_proj_q_scale3, rms_norm191), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv63 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv251,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv252 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_30_mlp_down_proj_q_weight3, transformer_h_30_mlp_down_proj_q_scale3, lv63), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv250_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv252, lv249_1, transformer_h_31_ln_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv251_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv250_1[1]
            rms_norm192: R.Tensor((1, seq_len, 3072), dtype="float16") = lv250_1[0]
            lv253 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_31_mixer_qkv_proj_q_weight3, transformer_h_31_mixer_qkv_proj_q_scale3, rms_norm192), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape380 = R.call_tir(cls.reshape4, (lv253,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape381 = R.call_tir(cls.reshape5, (reshape380,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv480 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape381), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape382 = R.call_tir(cls.reshape6, (lv480,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape383 = R.call_tir(cls.reshape7, (reshape382,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv254 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_31_mixer_out_proj_q_weight3, transformer_h_31_mixer_out_proj_q_scale3, reshape383), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv252_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv254, lv251_1, transformer_h_31_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv253_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv252_1[1]
            rms_norm193: R.Tensor((1, seq_len, 3072), dtype="float16") = lv252_1[0]
            lv255 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_31_mlp_gate_up_proj_q_weight3, transformer_h_31_mlp_gate_up_proj_q_scale3, rms_norm193), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv64 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv255,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv256 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_31_mlp_down_proj_q_weight3, transformer_h_31_mlp_down_proj_q_scale3, lv64), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv254_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv256, lv253_1, transformer_norm_weight3), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            rms_norm194: R.Tensor((1, seq_len, 3072), dtype="float16") = lv254_1[0]
            take1 = R.call_tir(cls.take, (rms_norm194, logit_positions), out_sinfo=R.Tensor((1, batch_size, 3072), dtype="float16"))
            lv257 = R.call_tir(cls.fused_dequantize5_fused_NT_matmul9_cast1, (lm_head_q_weight3, lm_head_q_scale3, take1), out_sinfo=R.Tensor((1, batch_size, vocab_size), dtype="float32"))
            gv3: R.Tuple(R.Tensor((1, batch_size, vocab_size), dtype="float32"), R.Object) = lv257, paged_kv_cache
            R.output(gv3)
        return gv3

    @R.function
    def batch_verify(input_embeds: R.Tensor((1, "seq_len", 3072), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32064, 384), dtype="uint32"), R.Tensor((32064, 96), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), 
R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 
1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), 
dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), 
R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), 
dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), 
R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor(("vocab_size", 384), dtype="uint32"), R.Tensor(("vocab_size", 96), dtype="float16"))) -> R.Tuple(R.Tensor((1, "seq_len", "vocab_size"), dtype="float32"), R.Object):
        seq_len = T.int64()
        vocab_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size", "seq_len"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        with R.dataflow():
            transformer_h_0_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[2]
            transformer_h_0_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[3]
            transformer_h_0_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[4]
            transformer_h_0_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[5]
            transformer_h_0_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[6]
            transformer_h_0_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[7]
            transformer_h_0_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[8]
            transformer_h_0_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[9]
            transformer_h_0_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[10]
            transformer_h_0_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[11]
            transformer_h_1_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[12]
            transformer_h_1_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[13]
            transformer_h_1_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[14]
            transformer_h_1_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[15]
            transformer_h_1_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[16]
            transformer_h_1_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[17]
            transformer_h_1_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[18]
            transformer_h_1_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[19]
            transformer_h_1_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[20]
            transformer_h_1_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[21]
            transformer_h_2_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[22]
            transformer_h_2_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[23]
            transformer_h_2_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[24]
            transformer_h_2_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[25]
            transformer_h_2_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[26]
            transformer_h_2_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[27]
            transformer_h_2_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[28]
            transformer_h_2_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[29]
            transformer_h_2_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[30]
            transformer_h_2_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[31]
            transformer_h_3_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[32]
            transformer_h_3_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[33]
            transformer_h_3_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[34]
            transformer_h_3_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[35]
            transformer_h_3_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[36]
            transformer_h_3_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[37]
            transformer_h_3_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[38]
            transformer_h_3_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[39]
            transformer_h_3_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[40]
            transformer_h_3_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[41]
            transformer_h_4_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[42]
            transformer_h_4_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[43]
            transformer_h_4_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[44]
            transformer_h_4_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[45]
            transformer_h_4_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[46]
            transformer_h_4_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[47]
            transformer_h_4_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[48]
            transformer_h_4_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[49]
            transformer_h_4_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[50]
            transformer_h_4_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[51]
            transformer_h_5_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[52]
            transformer_h_5_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[53]
            transformer_h_5_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[54]
            transformer_h_5_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[55]
            transformer_h_5_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[56]
            transformer_h_5_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[57]
            transformer_h_5_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[58]
            transformer_h_5_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[59]
            transformer_h_5_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[60]
            transformer_h_5_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[61]
            transformer_h_6_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[62]
            transformer_h_6_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[63]
            transformer_h_6_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[64]
            transformer_h_6_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[65]
            transformer_h_6_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[66]
            transformer_h_6_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[67]
            transformer_h_6_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[68]
            transformer_h_6_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[69]
            transformer_h_6_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[70]
            transformer_h_6_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[71]
            transformer_h_7_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[72]
            transformer_h_7_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[73]
            transformer_h_7_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[74]
            transformer_h_7_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[75]
            transformer_h_7_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[76]
            transformer_h_7_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[77]
            transformer_h_7_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[78]
            transformer_h_7_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[79]
            transformer_h_7_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[80]
            transformer_h_7_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[81]
            transformer_h_8_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[82]
            transformer_h_8_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[83]
            transformer_h_8_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[84]
            transformer_h_8_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[85]
            transformer_h_8_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[86]
            transformer_h_8_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[87]
            transformer_h_8_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[88]
            transformer_h_8_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[89]
            transformer_h_8_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[90]
            transformer_h_8_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[91]
            transformer_h_9_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[92]
            transformer_h_9_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[93]
            transformer_h_9_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[94]
            transformer_h_9_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[95]
            transformer_h_9_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[96]
            transformer_h_9_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[97]
            transformer_h_9_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[98]
            transformer_h_9_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[99]
            transformer_h_9_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[100]
            transformer_h_9_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[101]
            transformer_h_10_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[102]
            transformer_h_10_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[103]
            transformer_h_10_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[104]
            transformer_h_10_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[105]
            transformer_h_10_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[106]
            transformer_h_10_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[107]
            transformer_h_10_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[108]
            transformer_h_10_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[109]
            transformer_h_10_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[110]
            transformer_h_10_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[111]
            transformer_h_11_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[112]
            transformer_h_11_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[113]
            transformer_h_11_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[114]
            transformer_h_11_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[115]
            transformer_h_11_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[116]
            transformer_h_11_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[117]
            transformer_h_11_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[118]
            transformer_h_11_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[119]
            transformer_h_11_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[120]
            transformer_h_11_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[121]
            transformer_h_12_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[122]
            transformer_h_12_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[123]
            transformer_h_12_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[124]
            transformer_h_12_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[125]
            transformer_h_12_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[126]
            transformer_h_12_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[127]
            transformer_h_12_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[128]
            transformer_h_12_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[129]
            transformer_h_12_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[130]
            transformer_h_12_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[131]
            transformer_h_13_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[132]
            transformer_h_13_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[133]
            transformer_h_13_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[134]
            transformer_h_13_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[135]
            transformer_h_13_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[136]
            transformer_h_13_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[137]
            transformer_h_13_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[138]
            transformer_h_13_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[139]
            transformer_h_13_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[140]
            transformer_h_13_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[141]
            transformer_h_14_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[142]
            transformer_h_14_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[143]
            transformer_h_14_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[144]
            transformer_h_14_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[145]
            transformer_h_14_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[146]
            transformer_h_14_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[147]
            transformer_h_14_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[148]
            transformer_h_14_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[149]
            transformer_h_14_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[150]
            transformer_h_14_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[151]
            transformer_h_15_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[152]
            transformer_h_15_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[153]
            transformer_h_15_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[154]
            transformer_h_15_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[155]
            transformer_h_15_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[156]
            transformer_h_15_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[157]
            transformer_h_15_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[158]
            transformer_h_15_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[159]
            transformer_h_15_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[160]
            transformer_h_15_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[161]
            transformer_h_16_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[162]
            transformer_h_16_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[163]
            transformer_h_16_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[164]
            transformer_h_16_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[165]
            transformer_h_16_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[166]
            transformer_h_16_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[167]
            transformer_h_16_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[168]
            transformer_h_16_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[169]
            transformer_h_16_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[170]
            transformer_h_16_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[171]
            transformer_h_17_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[172]
            transformer_h_17_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[173]
            transformer_h_17_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[174]
            transformer_h_17_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[175]
            transformer_h_17_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[176]
            transformer_h_17_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[177]
            transformer_h_17_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[178]
            transformer_h_17_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[179]
            transformer_h_17_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[180]
            transformer_h_17_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[181]
            transformer_h_18_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[182]
            transformer_h_18_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[183]
            transformer_h_18_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[184]
            transformer_h_18_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[185]
            transformer_h_18_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[186]
            transformer_h_18_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[187]
            transformer_h_18_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[188]
            transformer_h_18_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[189]
            transformer_h_18_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[190]
            transformer_h_18_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[191]
            transformer_h_19_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[192]
            transformer_h_19_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[193]
            transformer_h_19_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[194]
            transformer_h_19_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[195]
            transformer_h_19_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[196]
            transformer_h_19_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[197]
            transformer_h_19_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[198]
            transformer_h_19_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[199]
            transformer_h_19_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[200]
            transformer_h_19_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[201]
            transformer_h_20_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[202]
            transformer_h_20_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[203]
            transformer_h_20_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[204]
            transformer_h_20_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[205]
            transformer_h_20_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[206]
            transformer_h_20_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[207]
            transformer_h_20_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[208]
            transformer_h_20_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[209]
            transformer_h_20_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[210]
            transformer_h_20_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[211]
            transformer_h_21_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[212]
            transformer_h_21_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[213]
            transformer_h_21_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[214]
            transformer_h_21_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[215]
            transformer_h_21_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[216]
            transformer_h_21_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[217]
            transformer_h_21_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[218]
            transformer_h_21_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[219]
            transformer_h_21_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[220]
            transformer_h_21_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[221]
            transformer_h_22_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[222]
            transformer_h_22_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[223]
            transformer_h_22_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[224]
            transformer_h_22_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[225]
            transformer_h_22_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[226]
            transformer_h_22_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[227]
            transformer_h_22_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[228]
            transformer_h_22_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[229]
            transformer_h_22_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[230]
            transformer_h_22_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[231]
            transformer_h_23_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[232]
            transformer_h_23_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[233]
            transformer_h_23_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[234]
            transformer_h_23_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[235]
            transformer_h_23_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[236]
            transformer_h_23_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[237]
            transformer_h_23_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[238]
            transformer_h_23_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[239]
            transformer_h_23_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[240]
            transformer_h_23_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[241]
            transformer_h_24_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[242]
            transformer_h_24_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[243]
            transformer_h_24_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[244]
            transformer_h_24_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[245]
            transformer_h_24_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[246]
            transformer_h_24_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[247]
            transformer_h_24_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[248]
            transformer_h_24_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[249]
            transformer_h_24_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[250]
            transformer_h_24_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[251]
            transformer_h_25_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[252]
            transformer_h_25_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[253]
            transformer_h_25_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[254]
            transformer_h_25_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[255]
            transformer_h_25_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[256]
            transformer_h_25_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[257]
            transformer_h_25_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[258]
            transformer_h_25_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[259]
            transformer_h_25_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[260]
            transformer_h_25_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[261]
            transformer_h_26_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[262]
            transformer_h_26_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[263]
            transformer_h_26_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[264]
            transformer_h_26_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[265]
            transformer_h_26_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[266]
            transformer_h_26_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[267]
            transformer_h_26_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[268]
            transformer_h_26_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[269]
            transformer_h_26_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[270]
            transformer_h_26_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[271]
            transformer_h_27_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[272]
            transformer_h_27_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[273]
            transformer_h_27_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[274]
            transformer_h_27_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[275]
            transformer_h_27_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[276]
            transformer_h_27_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[277]
            transformer_h_27_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[278]
            transformer_h_27_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[279]
            transformer_h_27_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[280]
            transformer_h_27_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[281]
            transformer_h_28_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[282]
            transformer_h_28_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[283]
            transformer_h_28_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[284]
            transformer_h_28_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[285]
            transformer_h_28_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[286]
            transformer_h_28_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[287]
            transformer_h_28_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[288]
            transformer_h_28_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[289]
            transformer_h_28_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[290]
            transformer_h_28_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[291]
            transformer_h_29_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[292]
            transformer_h_29_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[293]
            transformer_h_29_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[294]
            transformer_h_29_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[295]
            transformer_h_29_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[296]
            transformer_h_29_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[297]
            transformer_h_29_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[298]
            transformer_h_29_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[299]
            transformer_h_29_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[300]
            transformer_h_29_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[301]
            transformer_h_30_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[302]
            transformer_h_30_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[303]
            transformer_h_30_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[304]
            transformer_h_30_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[305]
            transformer_h_30_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[306]
            transformer_h_30_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[307]
            transformer_h_30_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[308]
            transformer_h_30_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[309]
            transformer_h_30_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[310]
            transformer_h_30_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[311]
            transformer_h_31_ln_weight5: R.Tensor((3072,), dtype="float16") = packed_params[312]
            transformer_h_31_mixer_qkv_proj_q_weight5: R.Tensor((9216, 384), dtype="uint32") = packed_params[313]
            transformer_h_31_mixer_qkv_proj_q_scale5: R.Tensor((9216, 96), dtype="float16") = packed_params[314]
            transformer_h_31_mixer_out_proj_q_weight5: R.Tensor((3072, 384), dtype="uint32") = packed_params[315]
            transformer_h_31_mixer_out_proj_q_scale5: R.Tensor((3072, 96), dtype="float16") = packed_params[316]
            transformer_h_31_mlp_gate_up_proj_q_weight5: R.Tensor((16384, 384), dtype="uint32") = packed_params[317]
            transformer_h_31_mlp_gate_up_proj_q_scale5: R.Tensor((16384, 96), dtype="float16") = packed_params[318]
            transformer_h_31_mlp_down_proj_q_weight5: R.Tensor((3072, 1024), dtype="uint32") = packed_params[319]
            transformer_h_31_mlp_down_proj_q_scale5: R.Tensor((3072, 256), dtype="float16") = packed_params[320]
            transformer_h_31_post_attention_layernorm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[321]
            transformer_norm_weight5: R.Tensor((3072,), dtype="float16") = packed_params[322]
            lm_head_q_weight5: R.Tensor((vocab_size, 384), dtype="uint32") = packed_params[323]
            lm_head_q_scale5: R.Tensor((vocab_size, 96), dtype="float16") = packed_params[324]
            rms_norm260 = R.call_tir(cls.rms_norm1, (input_embeds, transformer_h_0_ln_weight5), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv258 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_0_mixer_qkv_proj_q_weight5, transformer_h_0_mixer_qkv_proj_q_scale5, rms_norm260), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape512 = R.call_tir(cls.reshape4, (lv258,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape513 = R.call_tir(cls.reshape5, (reshape512,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv647 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape513), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape514 = R.call_tir(cls.reshape6, (lv647,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape515 = R.call_tir(cls.reshape7, (reshape514,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv259 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_0_mixer_out_proj_q_weight5, transformer_h_0_mixer_out_proj_q_scale5, reshape515), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv256 = R.call_tir(cls.fuse_add_norm_prefill, (lv259, input_embeds, transformer_h_0_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv257: R.Tensor((1, seq_len, 3072), dtype="float16") = lv256[1]
            rms_norm261: R.Tensor((1, seq_len, 3072), dtype="float16") = lv256[0]
            lv260 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_0_mlp_gate_up_proj_q_weight5, transformer_h_0_mlp_gate_up_proj_q_scale5, rms_norm261), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv66 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv260,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv261 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_0_mlp_down_proj_q_weight5, transformer_h_0_mlp_down_proj_q_scale5, lv66), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv258_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv261, lv257, transformer_h_1_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv259_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv258_1[1]
            rms_norm262: R.Tensor((1, seq_len, 3072), dtype="float16") = lv258_1[0]
            lv262 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_1_mixer_qkv_proj_q_weight5, transformer_h_1_mixer_qkv_proj_q_scale5, rms_norm262), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape516 = R.call_tir(cls.reshape4, (lv262,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape517 = R.call_tir(cls.reshape5, (reshape516,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv652 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape517), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape518 = R.call_tir(cls.reshape6, (lv652,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape519 = R.call_tir(cls.reshape7, (reshape518,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv263 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_1_mixer_out_proj_q_weight5, transformer_h_1_mixer_out_proj_q_scale5, reshape519), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv260_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv263, lv259_1, transformer_h_1_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv261_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv260_1[1]
            rms_norm263: R.Tensor((1, seq_len, 3072), dtype="float16") = lv260_1[0]
            lv264 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_1_mlp_gate_up_proj_q_weight5, transformer_h_1_mlp_gate_up_proj_q_scale5, rms_norm263), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv67 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv264,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv265 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_1_mlp_down_proj_q_weight5, transformer_h_1_mlp_down_proj_q_scale5, lv67), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv262_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv265, lv261_1, transformer_h_2_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv263_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv262_1[1]
            rms_norm264: R.Tensor((1, seq_len, 3072), dtype="float16") = lv262_1[0]
            lv266 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_2_mixer_qkv_proj_q_weight5, transformer_h_2_mixer_qkv_proj_q_scale5, rms_norm264), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape520 = R.call_tir(cls.reshape4, (lv266,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape521 = R.call_tir(cls.reshape5, (reshape520,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv657 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape521), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape522 = R.call_tir(cls.reshape6, (lv657,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape523 = R.call_tir(cls.reshape7, (reshape522,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv267 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_2_mixer_out_proj_q_weight5, transformer_h_2_mixer_out_proj_q_scale5, reshape523), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv264_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv267, lv263_1, transformer_h_2_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv265_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv264_1[1]
            rms_norm265: R.Tensor((1, seq_len, 3072), dtype="float16") = lv264_1[0]
            lv268 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_2_mlp_gate_up_proj_q_weight5, transformer_h_2_mlp_gate_up_proj_q_scale5, rms_norm265), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv68 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv268,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv269 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_2_mlp_down_proj_q_weight5, transformer_h_2_mlp_down_proj_q_scale5, lv68), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv266_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv269, lv265_1, transformer_h_3_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv267_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv266_1[1]
            rms_norm266: R.Tensor((1, seq_len, 3072), dtype="float16") = lv266_1[0]
            lv270 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_3_mixer_qkv_proj_q_weight5, transformer_h_3_mixer_qkv_proj_q_scale5, rms_norm266), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape524 = R.call_tir(cls.reshape4, (lv270,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape525 = R.call_tir(cls.reshape5, (reshape524,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv662 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape525), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape526 = R.call_tir(cls.reshape6, (lv662,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape527 = R.call_tir(cls.reshape7, (reshape526,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv271 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_3_mixer_out_proj_q_weight5, transformer_h_3_mixer_out_proj_q_scale5, reshape527), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv268_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv271, lv267_1, transformer_h_3_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv269_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv268_1[1]
            rms_norm267: R.Tensor((1, seq_len, 3072), dtype="float16") = lv268_1[0]
            lv272 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_3_mlp_gate_up_proj_q_weight5, transformer_h_3_mlp_gate_up_proj_q_scale5, rms_norm267), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv69 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv272,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv273 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_3_mlp_down_proj_q_weight5, transformer_h_3_mlp_down_proj_q_scale5, lv69), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv270_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv273, lv269_1, transformer_h_4_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv271_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv270_1[1]
            rms_norm268: R.Tensor((1, seq_len, 3072), dtype="float16") = lv270_1[0]
            lv274 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_4_mixer_qkv_proj_q_weight5, transformer_h_4_mixer_qkv_proj_q_scale5, rms_norm268), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape528 = R.call_tir(cls.reshape4, (lv274,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape529 = R.call_tir(cls.reshape5, (reshape528,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv667 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape529), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape530 = R.call_tir(cls.reshape6, (lv667,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape531 = R.call_tir(cls.reshape7, (reshape530,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv275 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_4_mixer_out_proj_q_weight5, transformer_h_4_mixer_out_proj_q_scale5, reshape531), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv272_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv275, lv271_1, transformer_h_4_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv273_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv272_1[1]
            rms_norm269: R.Tensor((1, seq_len, 3072), dtype="float16") = lv272_1[0]
            lv276 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_4_mlp_gate_up_proj_q_weight5, transformer_h_4_mlp_gate_up_proj_q_scale5, rms_norm269), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv70 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv276,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv277 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_4_mlp_down_proj_q_weight5, transformer_h_4_mlp_down_proj_q_scale5, lv70), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv274_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv277, lv273_1, transformer_h_5_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv275_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv274_1[1]
            rms_norm270: R.Tensor((1, seq_len, 3072), dtype="float16") = lv274_1[0]
            lv278 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_5_mixer_qkv_proj_q_weight5, transformer_h_5_mixer_qkv_proj_q_scale5, rms_norm270), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape532 = R.call_tir(cls.reshape4, (lv278,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape533 = R.call_tir(cls.reshape5, (reshape532,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv672 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape533), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape534 = R.call_tir(cls.reshape6, (lv672,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape535 = R.call_tir(cls.reshape7, (reshape534,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv279 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_5_mixer_out_proj_q_weight5, transformer_h_5_mixer_out_proj_q_scale5, reshape535), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv276_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv279, lv275_1, transformer_h_5_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv277_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv276_1[1]
            rms_norm271: R.Tensor((1, seq_len, 3072), dtype="float16") = lv276_1[0]
            lv280 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_5_mlp_gate_up_proj_q_weight5, transformer_h_5_mlp_gate_up_proj_q_scale5, rms_norm271), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv71 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv280,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv281 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_5_mlp_down_proj_q_weight5, transformer_h_5_mlp_down_proj_q_scale5, lv71), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv278_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv281, lv277_1, transformer_h_6_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv279_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv278_1[1]
            rms_norm272: R.Tensor((1, seq_len, 3072), dtype="float16") = lv278_1[0]
            lv282 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_6_mixer_qkv_proj_q_weight5, transformer_h_6_mixer_qkv_proj_q_scale5, rms_norm272), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape536 = R.call_tir(cls.reshape4, (lv282,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape537 = R.call_tir(cls.reshape5, (reshape536,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv677 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape537), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape538 = R.call_tir(cls.reshape6, (lv677,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape539 = R.call_tir(cls.reshape7, (reshape538,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv283 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_6_mixer_out_proj_q_weight5, transformer_h_6_mixer_out_proj_q_scale5, reshape539), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv280_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv283, lv279_1, transformer_h_6_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv281_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv280_1[1]
            rms_norm273: R.Tensor((1, seq_len, 3072), dtype="float16") = lv280_1[0]
            lv284 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_6_mlp_gate_up_proj_q_weight5, transformer_h_6_mlp_gate_up_proj_q_scale5, rms_norm273), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv72 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv284,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv285 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_6_mlp_down_proj_q_weight5, transformer_h_6_mlp_down_proj_q_scale5, lv72), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv282_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv285, lv281_1, transformer_h_7_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv283_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv282_1[1]
            rms_norm274: R.Tensor((1, seq_len, 3072), dtype="float16") = lv282_1[0]
            lv286 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_7_mixer_qkv_proj_q_weight5, transformer_h_7_mixer_qkv_proj_q_scale5, rms_norm274), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape540 = R.call_tir(cls.reshape4, (lv286,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape541 = R.call_tir(cls.reshape5, (reshape540,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv682 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape541), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape542 = R.call_tir(cls.reshape6, (lv682,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape543 = R.call_tir(cls.reshape7, (reshape542,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv287 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_7_mixer_out_proj_q_weight5, transformer_h_7_mixer_out_proj_q_scale5, reshape543), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv284_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv287, lv283_1, transformer_h_7_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv285_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv284_1[1]
            rms_norm275: R.Tensor((1, seq_len, 3072), dtype="float16") = lv284_1[0]
            lv288 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_7_mlp_gate_up_proj_q_weight5, transformer_h_7_mlp_gate_up_proj_q_scale5, rms_norm275), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv73 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv288,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv289 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_7_mlp_down_proj_q_weight5, transformer_h_7_mlp_down_proj_q_scale5, lv73), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv286_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv289, lv285_1, transformer_h_8_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv287_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv286_1[1]
            rms_norm276: R.Tensor((1, seq_len, 3072), dtype="float16") = lv286_1[0]
            lv290 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_8_mixer_qkv_proj_q_weight5, transformer_h_8_mixer_qkv_proj_q_scale5, rms_norm276), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape544 = R.call_tir(cls.reshape4, (lv290,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape545 = R.call_tir(cls.reshape5, (reshape544,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv687 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape545), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape546 = R.call_tir(cls.reshape6, (lv687,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape547 = R.call_tir(cls.reshape7, (reshape546,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv291 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_8_mixer_out_proj_q_weight5, transformer_h_8_mixer_out_proj_q_scale5, reshape547), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv288_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv291, lv287_1, transformer_h_8_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv289_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv288_1[1]
            rms_norm277: R.Tensor((1, seq_len, 3072), dtype="float16") = lv288_1[0]
            lv292 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_8_mlp_gate_up_proj_q_weight5, transformer_h_8_mlp_gate_up_proj_q_scale5, rms_norm277), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv74 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv292,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv293 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_8_mlp_down_proj_q_weight5, transformer_h_8_mlp_down_proj_q_scale5, lv74), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv290_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv293, lv289_1, transformer_h_9_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv291_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv290_1[1]
            rms_norm278: R.Tensor((1, seq_len, 3072), dtype="float16") = lv290_1[0]
            lv294 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_9_mixer_qkv_proj_q_weight5, transformer_h_9_mixer_qkv_proj_q_scale5, rms_norm278), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape548 = R.call_tir(cls.reshape4, (lv294,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape549 = R.call_tir(cls.reshape5, (reshape548,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv692 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape549), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape550 = R.call_tir(cls.reshape6, (lv692,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape551 = R.call_tir(cls.reshape7, (reshape550,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv295 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_9_mixer_out_proj_q_weight5, transformer_h_9_mixer_out_proj_q_scale5, reshape551), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv292_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv295, lv291_1, transformer_h_9_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv293_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv292_1[1]
            rms_norm279: R.Tensor((1, seq_len, 3072), dtype="float16") = lv292_1[0]
            lv296 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_9_mlp_gate_up_proj_q_weight5, transformer_h_9_mlp_gate_up_proj_q_scale5, rms_norm279), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv75 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv296,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv297 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_9_mlp_down_proj_q_weight5, transformer_h_9_mlp_down_proj_q_scale5, lv75), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv294_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv297, lv293_1, transformer_h_10_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv295_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv294_1[1]
            rms_norm280: R.Tensor((1, seq_len, 3072), dtype="float16") = lv294_1[0]
            lv298 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_10_mixer_qkv_proj_q_weight5, transformer_h_10_mixer_qkv_proj_q_scale5, rms_norm280), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape552 = R.call_tir(cls.reshape4, (lv298,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape553 = R.call_tir(cls.reshape5, (reshape552,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv697 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape553), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape554 = R.call_tir(cls.reshape6, (lv697,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape555 = R.call_tir(cls.reshape7, (reshape554,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv299 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_10_mixer_out_proj_q_weight5, transformer_h_10_mixer_out_proj_q_scale5, reshape555), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv296_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv299, lv295_1, transformer_h_10_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv297_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv296_1[1]
            rms_norm281: R.Tensor((1, seq_len, 3072), dtype="float16") = lv296_1[0]
            lv300 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_10_mlp_gate_up_proj_q_weight5, transformer_h_10_mlp_gate_up_proj_q_scale5, rms_norm281), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv76 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv300,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv301 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_10_mlp_down_proj_q_weight5, transformer_h_10_mlp_down_proj_q_scale5, lv76), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv298_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv301, lv297_1, transformer_h_11_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv299_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv298_1[1]
            rms_norm282: R.Tensor((1, seq_len, 3072), dtype="float16") = lv298_1[0]
            lv302 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_11_mixer_qkv_proj_q_weight5, transformer_h_11_mixer_qkv_proj_q_scale5, rms_norm282), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape556 = R.call_tir(cls.reshape4, (lv302,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape557 = R.call_tir(cls.reshape5, (reshape556,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv702 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape557), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape558 = R.call_tir(cls.reshape6, (lv702,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape559 = R.call_tir(cls.reshape7, (reshape558,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv303 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_11_mixer_out_proj_q_weight5, transformer_h_11_mixer_out_proj_q_scale5, reshape559), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv300_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv303, lv299_1, transformer_h_11_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv301_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv300_1[1]
            rms_norm283: R.Tensor((1, seq_len, 3072), dtype="float16") = lv300_1[0]
            lv304 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_11_mlp_gate_up_proj_q_weight5, transformer_h_11_mlp_gate_up_proj_q_scale5, rms_norm283), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv77 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv304,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv305 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_11_mlp_down_proj_q_weight5, transformer_h_11_mlp_down_proj_q_scale5, lv77), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv302_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv305, lv301_1, transformer_h_12_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv303_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv302_1[1]
            rms_norm284: R.Tensor((1, seq_len, 3072), dtype="float16") = lv302_1[0]
            lv306 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_12_mixer_qkv_proj_q_weight5, transformer_h_12_mixer_qkv_proj_q_scale5, rms_norm284), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape560 = R.call_tir(cls.reshape4, (lv306,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape561 = R.call_tir(cls.reshape5, (reshape560,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv707 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape561), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape562 = R.call_tir(cls.reshape6, (lv707,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape563 = R.call_tir(cls.reshape7, (reshape562,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv307 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_12_mixer_out_proj_q_weight5, transformer_h_12_mixer_out_proj_q_scale5, reshape563), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv304_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv307, lv303_1, transformer_h_12_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv305_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv304_1[1]
            rms_norm285: R.Tensor((1, seq_len, 3072), dtype="float16") = lv304_1[0]
            lv308 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_12_mlp_gate_up_proj_q_weight5, transformer_h_12_mlp_gate_up_proj_q_scale5, rms_norm285), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv78 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv308,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv309 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_12_mlp_down_proj_q_weight5, transformer_h_12_mlp_down_proj_q_scale5, lv78), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv306_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv309, lv305_1, transformer_h_13_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv307_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv306_1[1]
            rms_norm286: R.Tensor((1, seq_len, 3072), dtype="float16") = lv306_1[0]
            lv310 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_13_mixer_qkv_proj_q_weight5, transformer_h_13_mixer_qkv_proj_q_scale5, rms_norm286), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape564 = R.call_tir(cls.reshape4, (lv310,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape565 = R.call_tir(cls.reshape5, (reshape564,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv712 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape565), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape566 = R.call_tir(cls.reshape6, (lv712,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape567 = R.call_tir(cls.reshape7, (reshape566,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv311 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_13_mixer_out_proj_q_weight5, transformer_h_13_mixer_out_proj_q_scale5, reshape567), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv308_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv311, lv307_1, transformer_h_13_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv309_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv308_1[1]
            rms_norm287: R.Tensor((1, seq_len, 3072), dtype="float16") = lv308_1[0]
            lv312 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_13_mlp_gate_up_proj_q_weight5, transformer_h_13_mlp_gate_up_proj_q_scale5, rms_norm287), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv79 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv312,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv313 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_13_mlp_down_proj_q_weight5, transformer_h_13_mlp_down_proj_q_scale5, lv79), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv310_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv313, lv309_1, transformer_h_14_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv311_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv310_1[1]
            rms_norm288: R.Tensor((1, seq_len, 3072), dtype="float16") = lv310_1[0]
            lv314 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_14_mixer_qkv_proj_q_weight5, transformer_h_14_mixer_qkv_proj_q_scale5, rms_norm288), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape568 = R.call_tir(cls.reshape4, (lv314,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape569 = R.call_tir(cls.reshape5, (reshape568,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv717 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape569), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape570 = R.call_tir(cls.reshape6, (lv717,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape571 = R.call_tir(cls.reshape7, (reshape570,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv315 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_14_mixer_out_proj_q_weight5, transformer_h_14_mixer_out_proj_q_scale5, reshape571), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv312_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv315, lv311_1, transformer_h_14_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv313_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv312_1[1]
            rms_norm289: R.Tensor((1, seq_len, 3072), dtype="float16") = lv312_1[0]
            lv316 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_14_mlp_gate_up_proj_q_weight5, transformer_h_14_mlp_gate_up_proj_q_scale5, rms_norm289), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv80 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv316,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv317 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_14_mlp_down_proj_q_weight5, transformer_h_14_mlp_down_proj_q_scale5, lv80), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv314_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv317, lv313_1, transformer_h_15_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv315_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv314_1[1]
            rms_norm290: R.Tensor((1, seq_len, 3072), dtype="float16") = lv314_1[0]
            lv318 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_15_mixer_qkv_proj_q_weight5, transformer_h_15_mixer_qkv_proj_q_scale5, rms_norm290), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape572 = R.call_tir(cls.reshape4, (lv318,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape573 = R.call_tir(cls.reshape5, (reshape572,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv722 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape573), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape574 = R.call_tir(cls.reshape6, (lv722,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape575 = R.call_tir(cls.reshape7, (reshape574,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv319 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_15_mixer_out_proj_q_weight5, transformer_h_15_mixer_out_proj_q_scale5, reshape575), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv316_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv319, lv315_1, transformer_h_15_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv317_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv316_1[1]
            rms_norm291: R.Tensor((1, seq_len, 3072), dtype="float16") = lv316_1[0]
            lv320 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_15_mlp_gate_up_proj_q_weight5, transformer_h_15_mlp_gate_up_proj_q_scale5, rms_norm291), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv81 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv320,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv321 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_15_mlp_down_proj_q_weight5, transformer_h_15_mlp_down_proj_q_scale5, lv81), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv318_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv321, lv317_1, transformer_h_16_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv319_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv318_1[1]
            rms_norm292: R.Tensor((1, seq_len, 3072), dtype="float16") = lv318_1[0]
            lv322 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_16_mixer_qkv_proj_q_weight5, transformer_h_16_mixer_qkv_proj_q_scale5, rms_norm292), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape576 = R.call_tir(cls.reshape4, (lv322,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape577 = R.call_tir(cls.reshape5, (reshape576,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv727 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape577), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape578 = R.call_tir(cls.reshape6, (lv727,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape579 = R.call_tir(cls.reshape7, (reshape578,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv323 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_16_mixer_out_proj_q_weight5, transformer_h_16_mixer_out_proj_q_scale5, reshape579), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv320_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv323, lv319_1, transformer_h_16_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv321_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv320_1[1]
            rms_norm293: R.Tensor((1, seq_len, 3072), dtype="float16") = lv320_1[0]
            lv324 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_16_mlp_gate_up_proj_q_weight5, transformer_h_16_mlp_gate_up_proj_q_scale5, rms_norm293), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv82 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv324,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv325 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_16_mlp_down_proj_q_weight5, transformer_h_16_mlp_down_proj_q_scale5, lv82), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv322_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv325, lv321_1, transformer_h_17_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv323_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv322_1[1]
            rms_norm294: R.Tensor((1, seq_len, 3072), dtype="float16") = lv322_1[0]
            lv326 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_17_mixer_qkv_proj_q_weight5, transformer_h_17_mixer_qkv_proj_q_scale5, rms_norm294), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape580 = R.call_tir(cls.reshape4, (lv326,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape581 = R.call_tir(cls.reshape5, (reshape580,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv732 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape581), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape582 = R.call_tir(cls.reshape6, (lv732,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape583 = R.call_tir(cls.reshape7, (reshape582,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv327 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_17_mixer_out_proj_q_weight5, transformer_h_17_mixer_out_proj_q_scale5, reshape583), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv324_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv327, lv323_1, transformer_h_17_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv325_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv324_1[1]
            rms_norm295: R.Tensor((1, seq_len, 3072), dtype="float16") = lv324_1[0]
            lv328 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_17_mlp_gate_up_proj_q_weight5, transformer_h_17_mlp_gate_up_proj_q_scale5, rms_norm295), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv83 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv328,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv329 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_17_mlp_down_proj_q_weight5, transformer_h_17_mlp_down_proj_q_scale5, lv83), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv326_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv329, lv325_1, transformer_h_18_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv327_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv326_1[1]
            rms_norm296: R.Tensor((1, seq_len, 3072), dtype="float16") = lv326_1[0]
            lv330 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_18_mixer_qkv_proj_q_weight5, transformer_h_18_mixer_qkv_proj_q_scale5, rms_norm296), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape584 = R.call_tir(cls.reshape4, (lv330,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape585 = R.call_tir(cls.reshape5, (reshape584,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv737 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape585), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape586 = R.call_tir(cls.reshape6, (lv737,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape587 = R.call_tir(cls.reshape7, (reshape586,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv331 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_18_mixer_out_proj_q_weight5, transformer_h_18_mixer_out_proj_q_scale5, reshape587), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv328_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv331, lv327_1, transformer_h_18_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv329_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv328_1[1]
            rms_norm297: R.Tensor((1, seq_len, 3072), dtype="float16") = lv328_1[0]
            lv332 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_18_mlp_gate_up_proj_q_weight5, transformer_h_18_mlp_gate_up_proj_q_scale5, rms_norm297), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv84 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv332,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv333 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_18_mlp_down_proj_q_weight5, transformer_h_18_mlp_down_proj_q_scale5, lv84), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv330_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv333, lv329_1, transformer_h_19_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv331_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv330_1[1]
            rms_norm298: R.Tensor((1, seq_len, 3072), dtype="float16") = lv330_1[0]
            lv334 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_19_mixer_qkv_proj_q_weight5, transformer_h_19_mixer_qkv_proj_q_scale5, rms_norm298), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape588 = R.call_tir(cls.reshape4, (lv334,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape589 = R.call_tir(cls.reshape5, (reshape588,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv742 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape589), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape590 = R.call_tir(cls.reshape6, (lv742,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape591 = R.call_tir(cls.reshape7, (reshape590,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv335 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_19_mixer_out_proj_q_weight5, transformer_h_19_mixer_out_proj_q_scale5, reshape591), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv332_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv335, lv331_1, transformer_h_19_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv333_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv332_1[1]
            rms_norm299: R.Tensor((1, seq_len, 3072), dtype="float16") = lv332_1[0]
            lv336 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_19_mlp_gate_up_proj_q_weight5, transformer_h_19_mlp_gate_up_proj_q_scale5, rms_norm299), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv85 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv336,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv337 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_19_mlp_down_proj_q_weight5, transformer_h_19_mlp_down_proj_q_scale5, lv85), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv334_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv337, lv333_1, transformer_h_20_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv335_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv334_1[1]
            rms_norm300: R.Tensor((1, seq_len, 3072), dtype="float16") = lv334_1[0]
            lv338 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_20_mixer_qkv_proj_q_weight5, transformer_h_20_mixer_qkv_proj_q_scale5, rms_norm300), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape592 = R.call_tir(cls.reshape4, (lv338,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape593 = R.call_tir(cls.reshape5, (reshape592,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv747 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape593), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape594 = R.call_tir(cls.reshape6, (lv747,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape595 = R.call_tir(cls.reshape7, (reshape594,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv339 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_20_mixer_out_proj_q_weight5, transformer_h_20_mixer_out_proj_q_scale5, reshape595), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv336_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv339, lv335_1, transformer_h_20_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv337_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv336_1[1]
            rms_norm301: R.Tensor((1, seq_len, 3072), dtype="float16") = lv336_1[0]
            lv340 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_20_mlp_gate_up_proj_q_weight5, transformer_h_20_mlp_gate_up_proj_q_scale5, rms_norm301), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv86 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv340,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv341 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_20_mlp_down_proj_q_weight5, transformer_h_20_mlp_down_proj_q_scale5, lv86), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv338_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv341, lv337_1, transformer_h_21_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv339_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv338_1[1]
            rms_norm302: R.Tensor((1, seq_len, 3072), dtype="float16") = lv338_1[0]
            lv342 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_21_mixer_qkv_proj_q_weight5, transformer_h_21_mixer_qkv_proj_q_scale5, rms_norm302), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape596 = R.call_tir(cls.reshape4, (lv342,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape597 = R.call_tir(cls.reshape5, (reshape596,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv752 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape597), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape598 = R.call_tir(cls.reshape6, (lv752,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape599 = R.call_tir(cls.reshape7, (reshape598,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv343 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_21_mixer_out_proj_q_weight5, transformer_h_21_mixer_out_proj_q_scale5, reshape599), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv340_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv343, lv339_1, transformer_h_21_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv341_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv340_1[1]
            rms_norm303: R.Tensor((1, seq_len, 3072), dtype="float16") = lv340_1[0]
            lv344 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_21_mlp_gate_up_proj_q_weight5, transformer_h_21_mlp_gate_up_proj_q_scale5, rms_norm303), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv87 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv344,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv345 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_21_mlp_down_proj_q_weight5, transformer_h_21_mlp_down_proj_q_scale5, lv87), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv342_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv345, lv341_1, transformer_h_22_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv343_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv342_1[1]
            rms_norm304: R.Tensor((1, seq_len, 3072), dtype="float16") = lv342_1[0]
            lv346 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_22_mixer_qkv_proj_q_weight5, transformer_h_22_mixer_qkv_proj_q_scale5, rms_norm304), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape600 = R.call_tir(cls.reshape4, (lv346,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape601 = R.call_tir(cls.reshape5, (reshape600,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv757 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape601), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape602 = R.call_tir(cls.reshape6, (lv757,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape603 = R.call_tir(cls.reshape7, (reshape602,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv347 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_22_mixer_out_proj_q_weight5, transformer_h_22_mixer_out_proj_q_scale5, reshape603), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv344_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv347, lv343_1, transformer_h_22_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv345_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv344_1[1]
            rms_norm305: R.Tensor((1, seq_len, 3072), dtype="float16") = lv344_1[0]
            lv348 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_22_mlp_gate_up_proj_q_weight5, transformer_h_22_mlp_gate_up_proj_q_scale5, rms_norm305), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv88 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv348,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv349 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_22_mlp_down_proj_q_weight5, transformer_h_22_mlp_down_proj_q_scale5, lv88), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv346_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv349, lv345_1, transformer_h_23_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv347_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv346_1[1]
            rms_norm306: R.Tensor((1, seq_len, 3072), dtype="float16") = lv346_1[0]
            lv350 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_23_mixer_qkv_proj_q_weight5, transformer_h_23_mixer_qkv_proj_q_scale5, rms_norm306), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape604 = R.call_tir(cls.reshape4, (lv350,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape605 = R.call_tir(cls.reshape5, (reshape604,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv762 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape605), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape606 = R.call_tir(cls.reshape6, (lv762,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape607 = R.call_tir(cls.reshape7, (reshape606,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv351 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_23_mixer_out_proj_q_weight5, transformer_h_23_mixer_out_proj_q_scale5, reshape607), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv348_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv351, lv347_1, transformer_h_23_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv349_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv348_1[1]
            rms_norm307: R.Tensor((1, seq_len, 3072), dtype="float16") = lv348_1[0]
            lv352 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_23_mlp_gate_up_proj_q_weight5, transformer_h_23_mlp_gate_up_proj_q_scale5, rms_norm307), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv89 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv352,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv353 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_23_mlp_down_proj_q_weight5, transformer_h_23_mlp_down_proj_q_scale5, lv89), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv350_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv353, lv349_1, transformer_h_24_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv351_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv350_1[1]
            rms_norm308: R.Tensor((1, seq_len, 3072), dtype="float16") = lv350_1[0]
            lv354 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_24_mixer_qkv_proj_q_weight5, transformer_h_24_mixer_qkv_proj_q_scale5, rms_norm308), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape608 = R.call_tir(cls.reshape4, (lv354,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape609 = R.call_tir(cls.reshape5, (reshape608,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv767 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape609), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape610 = R.call_tir(cls.reshape6, (lv767,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape611 = R.call_tir(cls.reshape7, (reshape610,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv355 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_24_mixer_out_proj_q_weight5, transformer_h_24_mixer_out_proj_q_scale5, reshape611), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv352_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv355, lv351_1, transformer_h_24_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv353_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv352_1[1]
            rms_norm309: R.Tensor((1, seq_len, 3072), dtype="float16") = lv352_1[0]
            lv356 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_24_mlp_gate_up_proj_q_weight5, transformer_h_24_mlp_gate_up_proj_q_scale5, rms_norm309), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv90 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv356,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv357 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_24_mlp_down_proj_q_weight5, transformer_h_24_mlp_down_proj_q_scale5, lv90), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv354_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv357, lv353_1, transformer_h_25_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv355_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv354_1[1]
            rms_norm310: R.Tensor((1, seq_len, 3072), dtype="float16") = lv354_1[0]
            lv358 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_25_mixer_qkv_proj_q_weight5, transformer_h_25_mixer_qkv_proj_q_scale5, rms_norm310), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape612 = R.call_tir(cls.reshape4, (lv358,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape613 = R.call_tir(cls.reshape5, (reshape612,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv772 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape613), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape614 = R.call_tir(cls.reshape6, (lv772,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape615 = R.call_tir(cls.reshape7, (reshape614,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv359 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_25_mixer_out_proj_q_weight5, transformer_h_25_mixer_out_proj_q_scale5, reshape615), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv356_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv359, lv355_1, transformer_h_25_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv357_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv356_1[1]
            rms_norm311: R.Tensor((1, seq_len, 3072), dtype="float16") = lv356_1[0]
            lv360 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_25_mlp_gate_up_proj_q_weight5, transformer_h_25_mlp_gate_up_proj_q_scale5, rms_norm311), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv91 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv360,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv361 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_25_mlp_down_proj_q_weight5, transformer_h_25_mlp_down_proj_q_scale5, lv91), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv358_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv361, lv357_1, transformer_h_26_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv359_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv358_1[1]
            rms_norm312: R.Tensor((1, seq_len, 3072), dtype="float16") = lv358_1[0]
            lv362 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_26_mixer_qkv_proj_q_weight5, transformer_h_26_mixer_qkv_proj_q_scale5, rms_norm312), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape616 = R.call_tir(cls.reshape4, (lv362,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape617 = R.call_tir(cls.reshape5, (reshape616,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv777 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape617), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape618 = R.call_tir(cls.reshape6, (lv777,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape619 = R.call_tir(cls.reshape7, (reshape618,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv363 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_26_mixer_out_proj_q_weight5, transformer_h_26_mixer_out_proj_q_scale5, reshape619), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv360_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv363, lv359_1, transformer_h_26_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv361_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv360_1[1]
            rms_norm313: R.Tensor((1, seq_len, 3072), dtype="float16") = lv360_1[0]
            lv364 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_26_mlp_gate_up_proj_q_weight5, transformer_h_26_mlp_gate_up_proj_q_scale5, rms_norm313), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv92 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv364,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv365 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_26_mlp_down_proj_q_weight5, transformer_h_26_mlp_down_proj_q_scale5, lv92), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv362_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv365, lv361_1, transformer_h_27_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv363_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv362_1[1]
            rms_norm314: R.Tensor((1, seq_len, 3072), dtype="float16") = lv362_1[0]
            lv366 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_27_mixer_qkv_proj_q_weight5, transformer_h_27_mixer_qkv_proj_q_scale5, rms_norm314), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape620 = R.call_tir(cls.reshape4, (lv366,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape621 = R.call_tir(cls.reshape5, (reshape620,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv782 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape621), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape622 = R.call_tir(cls.reshape6, (lv782,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape623 = R.call_tir(cls.reshape7, (reshape622,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv367 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_27_mixer_out_proj_q_weight5, transformer_h_27_mixer_out_proj_q_scale5, reshape623), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv364_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv367, lv363_1, transformer_h_27_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv365_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv364_1[1]
            rms_norm315: R.Tensor((1, seq_len, 3072), dtype="float16") = lv364_1[0]
            lv368 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_27_mlp_gate_up_proj_q_weight5, transformer_h_27_mlp_gate_up_proj_q_scale5, rms_norm315), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv93 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv368,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv369 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_27_mlp_down_proj_q_weight5, transformer_h_27_mlp_down_proj_q_scale5, lv93), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv366_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv369, lv365_1, transformer_h_28_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv367_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv366_1[1]
            rms_norm316: R.Tensor((1, seq_len, 3072), dtype="float16") = lv366_1[0]
            lv370 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_28_mixer_qkv_proj_q_weight5, transformer_h_28_mixer_qkv_proj_q_scale5, rms_norm316), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape624 = R.call_tir(cls.reshape4, (lv370,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape625 = R.call_tir(cls.reshape5, (reshape624,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv787 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape625), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape626 = R.call_tir(cls.reshape6, (lv787,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape627 = R.call_tir(cls.reshape7, (reshape626,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv371 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_28_mixer_out_proj_q_weight5, transformer_h_28_mixer_out_proj_q_scale5, reshape627), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv368_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv371, lv367_1, transformer_h_28_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv369_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv368_1[1]
            rms_norm317: R.Tensor((1, seq_len, 3072), dtype="float16") = lv368_1[0]
            lv372 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_28_mlp_gate_up_proj_q_weight5, transformer_h_28_mlp_gate_up_proj_q_scale5, rms_norm317), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv94 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv372,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv373 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_28_mlp_down_proj_q_weight5, transformer_h_28_mlp_down_proj_q_scale5, lv94), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv370_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv373, lv369_1, transformer_h_29_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv371_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv370_1[1]
            rms_norm318: R.Tensor((1, seq_len, 3072), dtype="float16") = lv370_1[0]
            lv374 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_29_mixer_qkv_proj_q_weight5, transformer_h_29_mixer_qkv_proj_q_scale5, rms_norm318), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape628 = R.call_tir(cls.reshape4, (lv374,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape629 = R.call_tir(cls.reshape5, (reshape628,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv792 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape629), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape630 = R.call_tir(cls.reshape6, (lv792,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape631 = R.call_tir(cls.reshape7, (reshape630,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv375 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_29_mixer_out_proj_q_weight5, transformer_h_29_mixer_out_proj_q_scale5, reshape631), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv372_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv375, lv371_1, transformer_h_29_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv373_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv372_1[1]
            rms_norm319: R.Tensor((1, seq_len, 3072), dtype="float16") = lv372_1[0]
            lv376 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_29_mlp_gate_up_proj_q_weight5, transformer_h_29_mlp_gate_up_proj_q_scale5, rms_norm319), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv95 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv376,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv377 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_29_mlp_down_proj_q_weight5, transformer_h_29_mlp_down_proj_q_scale5, lv95), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv374_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv377, lv373_1, transformer_h_30_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv375_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv374_1[1]
            rms_norm320: R.Tensor((1, seq_len, 3072), dtype="float16") = lv374_1[0]
            lv378 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_30_mixer_qkv_proj_q_weight5, transformer_h_30_mixer_qkv_proj_q_scale5, rms_norm320), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape632 = R.call_tir(cls.reshape4, (lv378,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape633 = R.call_tir(cls.reshape5, (reshape632,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv797 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape633), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape634 = R.call_tir(cls.reshape6, (lv797,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape635 = R.call_tir(cls.reshape7, (reshape634,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv379 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_30_mixer_out_proj_q_weight5, transformer_h_30_mixer_out_proj_q_scale5, reshape635), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv376_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv379, lv375_1, transformer_h_30_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv377_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv376_1[1]
            rms_norm321: R.Tensor((1, seq_len, 3072), dtype="float16") = lv376_1[0]
            lv380 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_30_mlp_gate_up_proj_q_weight5, transformer_h_30_mlp_gate_up_proj_q_scale5, rms_norm321), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv96 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv380,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv381 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_30_mlp_down_proj_q_weight5, transformer_h_30_mlp_down_proj_q_scale5, lv96), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv378_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv381, lv377_1, transformer_h_31_ln_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv379_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv378_1[1]
            rms_norm322: R.Tensor((1, seq_len, 3072), dtype="float16") = lv378_1[0]
            lv382 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_31_mixer_qkv_proj_q_weight5, transformer_h_31_mixer_qkv_proj_q_scale5, rms_norm322), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape636 = R.call_tir(cls.reshape4, (lv382,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape637 = R.call_tir(cls.reshape5, (reshape636,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv802 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape637), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape638 = R.call_tir(cls.reshape6, (lv802,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape639 = R.call_tir(cls.reshape7, (reshape638,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv383 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_31_mixer_out_proj_q_weight5, transformer_h_31_mixer_out_proj_q_scale5, reshape639), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv380_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv383, lv379_1, transformer_h_31_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv381_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv380_1[1]
            rms_norm323: R.Tensor((1, seq_len, 3072), dtype="float16") = lv380_1[0]
            lv384 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_31_mlp_gate_up_proj_q_weight5, transformer_h_31_mlp_gate_up_proj_q_scale5, rms_norm323), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv97 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv384,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv385 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_31_mlp_down_proj_q_weight5, transformer_h_31_mlp_down_proj_q_scale5, lv97), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv382_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv385, lv381_1, transformer_norm_weight5), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            rms_norm324: R.Tensor((1, seq_len, 3072), dtype="float16") = lv382_1[0]
            lv386 = R.call_tir(cls.fused_dequantize5_fused_NT_matmul9_cast1, (lm_head_q_weight5, lm_head_q_scale5, rms_norm324), out_sinfo=R.Tensor((1, seq_len, vocab_size), dtype="float32"))
            gv5: R.Tuple(R.Tensor((1, seq_len, vocab_size), dtype="float32"), R.Object) = lv386, paged_kv_cache
            R.output(gv5)
        return gv5

    # Relax entry point that builds the paged KV-cache object used by the attention
    # calls elsewhere in this module ("vm.builtin.attention_kv_cache_attention_with_fused_qkv").
    # Each `*_` parameter is a one-element R.Shape whose only purpose is to bind one
    # symbolic int64 var: max_batch_size, max_total_seq_len, prefill_chunk_size,
    # page_size, support_sliding_window. All five are forwarded, in that order, as the
    # first shape argument of the packed builtin below.
    # Returns the opaque R.Object produced by
    # "vm.builtin.paged_attention_kv_cache_create_reduced" (declared via sinfo_args).
    # NOTE(review): this file appears to be a machine-generated TVMScript dump; hand edits
    # are likely to be lost if the module is regenerated.
    @R.function
    def create_tir_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object:
        # Symbolic int64 vars, bound by the R.Shape parameters in the signature above.
        max_batch_size = T.int64()
        max_total_seq_len = T.int64()
        prefill_chunk_size = T.int64()
        page_size = T.int64()
        support_sliding_window = T.int64()
        # NOTE(review): these attrs name vars that do not occur in this function's signature
        # (vocab_size, batch_size, seq_len, total_seq_len). They look like module-wide attrs
        # attached to every function by the generator; confirm they are intended here.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        # Positional arguments to the builtin. The meaning of each slot is defined by the
        # runtime builtin, not by this file; the readings below are inferred and should be
        # confirmed against the builtin's signature (TODO).
        #   - R.shape([max_batch_size, ...]): the five config values bound above.
        #   - R.shape([0, 32]): presumably the layer range [0, 32). Other functions in this
        #     module pass layer ids via R.prim_value(<id>) up to 31.
        #   - 32, 32, 96: presumably num_qo_heads, num_kv_heads, head_dim. This is consistent
        #     with the fused QKV input (seq_len, 96, 96) and the attention output
        #     (seq_len, 32, 96) used by the attention calls in this module.
        #   - 1, 1, T.float32(10000.0): meaning not visible here; presumably RoPE-related
        #     settings (e.g. mode / scale / theta). Confirm.
        #   - R.const(0.0, "float16"): presumably selects the cache element dtype. Confirm.
        #   - cls.<name>: TIR PrimFuncs from this module handed to the runtime as callbacks.
        #     Their roles are only implied by their names (append, prefill/decode attention,
        #     sliding-window variants, ragged prefill, state merge, RoPE, page copy, debug
        #     read, compaction, tree attention).
        #   - metadata["relax.expr.Constant"][0], R.prim_value(0): meaning not visible in
        #     this chunk. TODO confirm.
        # call_pure_packed: the call is treated as pure by Relax; its result is the cache object.
        paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.shape([0, 32]), R.prim_value(32), R.prim_value(32), R.prim_value(96), R.prim_value(1), R.prim_value(1), R.prim_value(T.float32(10000.0)), R.const(0.0, "float16"), cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope_longrope_scaling, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv, metadata["relax.expr.Constant"][0], R.prim_value(0), sinfo_args=(R.Object,))
        return paged_kv_cache

    @R.function
    def decode(input_embed: R.Tensor((1, 1, 3072), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32064, 384), dtype="uint32"), R.Tensor((32064, 96), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 
96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), 
dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), 
R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), 
dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), 
R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 
384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor(("vocab_size", 384), dtype="uint32"), R.Tensor(("vocab_size", 96), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, "vocab_size"), dtype="float32"), R.Object):
        vocab_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        with R.dataflow():
            transformer_h_0_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[2]
            transformer_h_0_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[3]
            transformer_h_0_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[4]
            transformer_h_0_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[5]
            transformer_h_0_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[6]
            transformer_h_0_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[7]
            transformer_h_0_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[8]
            transformer_h_0_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[9]
            transformer_h_0_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[10]
            transformer_h_0_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[11]
            transformer_h_1_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[12]
            transformer_h_1_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[13]
            transformer_h_1_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[14]
            transformer_h_1_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[15]
            transformer_h_1_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[16]
            transformer_h_1_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[17]
            transformer_h_1_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[18]
            transformer_h_1_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[19]
            transformer_h_1_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[20]
            transformer_h_1_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[21]
            transformer_h_2_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[22]
            transformer_h_2_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[23]
            transformer_h_2_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[24]
            transformer_h_2_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[25]
            transformer_h_2_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[26]
            transformer_h_2_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[27]
            transformer_h_2_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[28]
            transformer_h_2_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[29]
            transformer_h_2_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[30]
            transformer_h_2_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[31]
            transformer_h_3_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[32]
            transformer_h_3_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[33]
            transformer_h_3_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[34]
            transformer_h_3_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[35]
            transformer_h_3_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[36]
            transformer_h_3_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[37]
            transformer_h_3_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[38]
            transformer_h_3_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[39]
            transformer_h_3_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[40]
            transformer_h_3_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[41]
            transformer_h_4_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[42]
            transformer_h_4_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[43]
            transformer_h_4_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[44]
            transformer_h_4_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[45]
            transformer_h_4_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[46]
            transformer_h_4_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[47]
            transformer_h_4_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[48]
            transformer_h_4_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[49]
            transformer_h_4_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[50]
            transformer_h_4_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[51]
            transformer_h_5_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[52]
            transformer_h_5_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[53]
            transformer_h_5_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[54]
            transformer_h_5_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[55]
            transformer_h_5_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[56]
            transformer_h_5_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[57]
            transformer_h_5_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[58]
            transformer_h_5_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[59]
            transformer_h_5_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[60]
            transformer_h_5_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[61]
            transformer_h_6_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[62]
            transformer_h_6_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[63]
            transformer_h_6_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[64]
            transformer_h_6_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[65]
            transformer_h_6_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[66]
            transformer_h_6_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[67]
            transformer_h_6_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[68]
            transformer_h_6_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[69]
            transformer_h_6_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[70]
            transformer_h_6_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[71]
            transformer_h_7_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[72]
            transformer_h_7_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[73]
            transformer_h_7_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[74]
            transformer_h_7_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[75]
            transformer_h_7_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[76]
            transformer_h_7_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[77]
            transformer_h_7_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[78]
            transformer_h_7_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[79]
            transformer_h_7_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[80]
            transformer_h_7_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[81]
            transformer_h_8_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[82]
            transformer_h_8_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[83]
            transformer_h_8_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[84]
            transformer_h_8_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[85]
            transformer_h_8_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[86]
            transformer_h_8_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[87]
            transformer_h_8_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[88]
            transformer_h_8_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[89]
            transformer_h_8_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[90]
            transformer_h_8_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[91]
            transformer_h_9_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[92]
            transformer_h_9_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[93]
            transformer_h_9_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[94]
            transformer_h_9_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[95]
            transformer_h_9_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[96]
            transformer_h_9_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[97]
            transformer_h_9_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[98]
            transformer_h_9_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[99]
            transformer_h_9_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[100]
            transformer_h_9_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[101]
            transformer_h_10_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[102]
            transformer_h_10_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[103]
            transformer_h_10_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[104]
            transformer_h_10_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[105]
            transformer_h_10_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[106]
            transformer_h_10_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[107]
            transformer_h_10_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[108]
            transformer_h_10_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[109]
            transformer_h_10_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[110]
            transformer_h_10_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[111]
            transformer_h_11_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[112]
            transformer_h_11_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[113]
            transformer_h_11_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[114]
            transformer_h_11_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[115]
            transformer_h_11_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[116]
            transformer_h_11_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[117]
            transformer_h_11_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[118]
            transformer_h_11_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[119]
            transformer_h_11_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[120]
            transformer_h_11_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[121]
            transformer_h_12_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[122]
            transformer_h_12_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[123]
            transformer_h_12_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[124]
            transformer_h_12_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[125]
            transformer_h_12_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[126]
            transformer_h_12_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[127]
            transformer_h_12_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[128]
            transformer_h_12_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[129]
            transformer_h_12_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[130]
            transformer_h_12_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[131]
            transformer_h_13_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[132]
            transformer_h_13_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[133]
            transformer_h_13_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[134]
            transformer_h_13_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[135]
            transformer_h_13_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[136]
            transformer_h_13_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[137]
            transformer_h_13_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[138]
            transformer_h_13_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[139]
            transformer_h_13_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[140]
            transformer_h_13_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[141]
            transformer_h_14_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[142]
            transformer_h_14_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[143]
            transformer_h_14_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[144]
            transformer_h_14_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[145]
            transformer_h_14_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[146]
            transformer_h_14_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[147]
            transformer_h_14_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[148]
            transformer_h_14_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[149]
            transformer_h_14_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[150]
            transformer_h_14_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[151]
            transformer_h_15_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[152]
            transformer_h_15_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[153]
            transformer_h_15_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[154]
            transformer_h_15_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[155]
            transformer_h_15_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[156]
            transformer_h_15_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[157]
            transformer_h_15_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[158]
            transformer_h_15_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[159]
            transformer_h_15_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[160]
            transformer_h_15_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[161]
            transformer_h_16_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[162]
            transformer_h_16_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[163]
            transformer_h_16_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[164]
            transformer_h_16_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[165]
            transformer_h_16_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[166]
            transformer_h_16_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[167]
            transformer_h_16_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[168]
            transformer_h_16_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[169]
            transformer_h_16_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[170]
            transformer_h_16_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[171]
            transformer_h_17_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[172]
            transformer_h_17_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[173]
            transformer_h_17_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[174]
            transformer_h_17_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[175]
            transformer_h_17_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[176]
            transformer_h_17_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[177]
            transformer_h_17_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[178]
            transformer_h_17_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[179]
            transformer_h_17_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[180]
            transformer_h_17_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[181]
            transformer_h_18_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[182]
            transformer_h_18_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[183]
            transformer_h_18_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[184]
            transformer_h_18_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[185]
            transformer_h_18_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[186]
            transformer_h_18_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[187]
            transformer_h_18_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[188]
            transformer_h_18_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[189]
            transformer_h_18_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[190]
            transformer_h_18_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[191]
            transformer_h_19_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[192]
            transformer_h_19_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[193]
            transformer_h_19_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[194]
            transformer_h_19_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[195]
            transformer_h_19_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[196]
            transformer_h_19_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[197]
            transformer_h_19_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[198]
            transformer_h_19_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[199]
            transformer_h_19_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[200]
            transformer_h_19_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[201]
            transformer_h_20_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[202]
            transformer_h_20_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[203]
            transformer_h_20_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[204]
            transformer_h_20_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[205]
            transformer_h_20_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[206]
            transformer_h_20_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[207]
            transformer_h_20_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[208]
            transformer_h_20_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[209]
            transformer_h_20_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[210]
            transformer_h_20_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[211]
            transformer_h_21_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[212]
            transformer_h_21_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[213]
            transformer_h_21_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[214]
            transformer_h_21_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[215]
            transformer_h_21_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[216]
            transformer_h_21_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[217]
            transformer_h_21_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[218]
            transformer_h_21_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[219]
            transformer_h_21_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[220]
            transformer_h_21_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[221]
            transformer_h_22_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[222]
            transformer_h_22_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[223]
            transformer_h_22_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[224]
            transformer_h_22_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[225]
            transformer_h_22_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[226]
            transformer_h_22_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[227]
            transformer_h_22_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[228]
            transformer_h_22_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[229]
            transformer_h_22_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[230]
            transformer_h_22_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[231]
            transformer_h_23_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[232]
            transformer_h_23_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[233]
            transformer_h_23_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[234]
            transformer_h_23_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[235]
            transformer_h_23_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[236]
            transformer_h_23_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[237]
            transformer_h_23_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[238]
            transformer_h_23_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[239]
            transformer_h_23_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[240]
            transformer_h_23_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[241]
            transformer_h_24_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[242]
            transformer_h_24_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[243]
            transformer_h_24_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[244]
            transformer_h_24_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[245]
            transformer_h_24_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[246]
            transformer_h_24_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[247]
            transformer_h_24_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[248]
            transformer_h_24_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[249]
            transformer_h_24_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[250]
            transformer_h_24_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[251]
            transformer_h_25_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[252]
            transformer_h_25_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[253]
            transformer_h_25_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[254]
            transformer_h_25_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[255]
            transformer_h_25_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[256]
            transformer_h_25_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[257]
            transformer_h_25_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[258]
            transformer_h_25_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[259]
            transformer_h_25_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[260]
            transformer_h_25_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[261]
            transformer_h_26_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[262]
            transformer_h_26_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[263]
            transformer_h_26_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[264]
            transformer_h_26_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[265]
            transformer_h_26_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[266]
            transformer_h_26_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[267]
            transformer_h_26_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[268]
            transformer_h_26_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[269]
            transformer_h_26_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[270]
            transformer_h_26_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[271]
            transformer_h_27_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[272]
            transformer_h_27_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[273]
            transformer_h_27_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[274]
            transformer_h_27_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[275]
            transformer_h_27_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[276]
            transformer_h_27_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[277]
            transformer_h_27_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[278]
            transformer_h_27_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[279]
            transformer_h_27_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[280]
            transformer_h_27_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[281]
            transformer_h_28_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[282]
            transformer_h_28_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[283]
            transformer_h_28_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[284]
            transformer_h_28_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[285]
            transformer_h_28_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[286]
            transformer_h_28_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[287]
            transformer_h_28_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[288]
            transformer_h_28_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[289]
            transformer_h_28_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[290]
            transformer_h_28_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[291]
            transformer_h_29_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[292]
            transformer_h_29_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[293]
            transformer_h_29_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[294]
            transformer_h_29_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[295]
            transformer_h_29_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[296]
            transformer_h_29_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[297]
            transformer_h_29_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[298]
            transformer_h_29_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[299]
            transformer_h_29_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[300]
            transformer_h_29_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[301]
            transformer_h_30_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[302]
            transformer_h_30_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[303]
            transformer_h_30_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[304]
            transformer_h_30_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[305]
            transformer_h_30_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[306]
            transformer_h_30_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[307]
            transformer_h_30_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[308]
            transformer_h_30_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[309]
            transformer_h_30_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[310]
            transformer_h_30_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[311]
            transformer_h_31_ln_weight2: R.Tensor((3072,), dtype="float16") = packed_params[312]
            transformer_h_31_mixer_qkv_proj_q_weight2: R.Tensor((9216, 384), dtype="uint32") = packed_params[313]
            transformer_h_31_mixer_qkv_proj_q_scale2: R.Tensor((9216, 96), dtype="float16") = packed_params[314]
            transformer_h_31_mixer_out_proj_q_weight2: R.Tensor((3072, 384), dtype="uint32") = packed_params[315]
            transformer_h_31_mixer_out_proj_q_scale2: R.Tensor((3072, 96), dtype="float16") = packed_params[316]
            transformer_h_31_mlp_gate_up_proj_q_weight2: R.Tensor((16384, 384), dtype="uint32") = packed_params[317]
            transformer_h_31_mlp_gate_up_proj_q_scale2: R.Tensor((16384, 96), dtype="float16") = packed_params[318]
            transformer_h_31_mlp_down_proj_q_weight2: R.Tensor((3072, 1024), dtype="uint32") = packed_params[319]
            transformer_h_31_mlp_down_proj_q_scale2: R.Tensor((3072, 256), dtype="float16") = packed_params[320]
            transformer_h_31_post_attention_layernorm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[321]
            transformer_norm_weight2: R.Tensor((3072,), dtype="float16") = packed_params[322]
            lm_head_q_weight2: R.Tensor((vocab_size, 384), dtype="uint32") = packed_params[323]
            lm_head_q_scale2: R.Tensor((vocab_size, 96), dtype="float16") = packed_params[324]
            rms_norm65 = R.call_tir(cls.rms_norm2, (input_embed, transformer_h_0_ln_weight2), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv387 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_0_mixer_qkv_proj_q_weight2, transformer_h_0_mixer_qkv_proj_q_scale2, rms_norm65), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv99 = R.call_tir(cls.fused_reshape8_reshape9, (lv387,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv164 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), lv99), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv100 = R.call_tir(cls.fused_reshape10_reshape11, (lv164,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv388 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_0_mixer_out_proj_q_weight2, transformer_h_0_mixer_out_proj_q_scale2, lv100), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv384 = R.call_tir(cls.fuse_add_norm_prefill, (lv388, input_embed, transformer_h_0_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv385: R.Tensor((1, 1, 3072), dtype="float16") = lv384[1]
            rms_norm66: R.Tensor((1, 1, 3072), dtype="float16") = lv384[0]
            lv389 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_0_mlp_gate_up_proj_q_weight2, transformer_h_0_mlp_gate_up_proj_q_scale2, rms_norm66), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv101 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv389,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv390 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_0_mlp_down_proj_q_weight2, transformer_h_0_mlp_down_proj_q_scale2, lv101), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv386 = R.call_tir(cls.fuse_add_norm_prefill, (lv390, lv385, transformer_h_1_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv387_1: R.Tensor((1, 1, 3072), dtype="float16") = lv386[1]
            rms_norm67: R.Tensor((1, 1, 3072), dtype="float16") = lv386[0]
            lv391 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_1_mixer_qkv_proj_q_weight2, transformer_h_1_mixer_qkv_proj_q_scale2, rms_norm67), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv102 = R.call_tir(cls.fused_reshape8_reshape9, (lv391,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv169 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), lv102), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv103 = R.call_tir(cls.fused_reshape10_reshape11, (lv169,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv392 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_1_mixer_out_proj_q_weight2, transformer_h_1_mixer_out_proj_q_scale2, lv103), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv388_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv392, lv387_1, transformer_h_1_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv389_1: R.Tensor((1, 1, 3072), dtype="float16") = lv388_1[1]
            rms_norm68: R.Tensor((1, 1, 3072), dtype="float16") = lv388_1[0]
            lv393 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_1_mlp_gate_up_proj_q_weight2, transformer_h_1_mlp_gate_up_proj_q_scale2, rms_norm68), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv104 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv393,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv394 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_1_mlp_down_proj_q_weight2, transformer_h_1_mlp_down_proj_q_scale2, lv104), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv390_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv394, lv389_1, transformer_h_2_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv391_1: R.Tensor((1, 1, 3072), dtype="float16") = lv390_1[1]
            rms_norm69: R.Tensor((1, 1, 3072), dtype="float16") = lv390_1[0]
            lv395 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_2_mixer_qkv_proj_q_weight2, transformer_h_2_mixer_qkv_proj_q_scale2, rms_norm69), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv105 = R.call_tir(cls.fused_reshape8_reshape9, (lv395,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv174 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), lv105), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv106 = R.call_tir(cls.fused_reshape10_reshape11, (lv174,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv396 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_2_mixer_out_proj_q_weight2, transformer_h_2_mixer_out_proj_q_scale2, lv106), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv392_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv396, lv391_1, transformer_h_2_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv393_1: R.Tensor((1, 1, 3072), dtype="float16") = lv392_1[1]
            rms_norm70: R.Tensor((1, 1, 3072), dtype="float16") = lv392_1[0]
            lv397 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_2_mlp_gate_up_proj_q_weight2, transformer_h_2_mlp_gate_up_proj_q_scale2, rms_norm70), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv107 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv397,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv398 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_2_mlp_down_proj_q_weight2, transformer_h_2_mlp_down_proj_q_scale2, lv107), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv394_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv398, lv393_1, transformer_h_3_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv395_1: R.Tensor((1, 1, 3072), dtype="float16") = lv394_1[1]
            rms_norm71: R.Tensor((1, 1, 3072), dtype="float16") = lv394_1[0]
            lv399 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_3_mixer_qkv_proj_q_weight2, transformer_h_3_mixer_qkv_proj_q_scale2, rms_norm71), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv108 = R.call_tir(cls.fused_reshape8_reshape9, (lv399,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv179 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), lv108), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv109 = R.call_tir(cls.fused_reshape10_reshape11, (lv179,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv400 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_3_mixer_out_proj_q_weight2, transformer_h_3_mixer_out_proj_q_scale2, lv109), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv396_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv400, lv395_1, transformer_h_3_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv397_1: R.Tensor((1, 1, 3072), dtype="float16") = lv396_1[1]
            rms_norm72: R.Tensor((1, 1, 3072), dtype="float16") = lv396_1[0]
            lv401 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_3_mlp_gate_up_proj_q_weight2, transformer_h_3_mlp_gate_up_proj_q_scale2, rms_norm72), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv110 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv401,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv402 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_3_mlp_down_proj_q_weight2, transformer_h_3_mlp_down_proj_q_scale2, lv110), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv398_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv402, lv397_1, transformer_h_4_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv399_1: R.Tensor((1, 1, 3072), dtype="float16") = lv398_1[1]
            rms_norm73: R.Tensor((1, 1, 3072), dtype="float16") = lv398_1[0]
            lv403 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_4_mixer_qkv_proj_q_weight2, transformer_h_4_mixer_qkv_proj_q_scale2, rms_norm73), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv111 = R.call_tir(cls.fused_reshape8_reshape9, (lv403,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv184 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), lv111), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv112 = R.call_tir(cls.fused_reshape10_reshape11, (lv184,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv404 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_4_mixer_out_proj_q_weight2, transformer_h_4_mixer_out_proj_q_scale2, lv112), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv400_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv404, lv399_1, transformer_h_4_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv401_1: R.Tensor((1, 1, 3072), dtype="float16") = lv400_1[1]
            rms_norm74: R.Tensor((1, 1, 3072), dtype="float16") = lv400_1[0]
            lv405 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_4_mlp_gate_up_proj_q_weight2, transformer_h_4_mlp_gate_up_proj_q_scale2, rms_norm74), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv113 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv405,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv406 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_4_mlp_down_proj_q_weight2, transformer_h_4_mlp_down_proj_q_scale2, lv113), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv402_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv406, lv401_1, transformer_h_5_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv403_1: R.Tensor((1, 1, 3072), dtype="float16") = lv402_1[1]
            rms_norm75: R.Tensor((1, 1, 3072), dtype="float16") = lv402_1[0]
            lv407 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_5_mixer_qkv_proj_q_weight2, transformer_h_5_mixer_qkv_proj_q_scale2, rms_norm75), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv114 = R.call_tir(cls.fused_reshape8_reshape9, (lv407,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv189 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), lv114), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv115 = R.call_tir(cls.fused_reshape10_reshape11, (lv189,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv408 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_5_mixer_out_proj_q_weight2, transformer_h_5_mixer_out_proj_q_scale2, lv115), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv404_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv408, lv403_1, transformer_h_5_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv405_1: R.Tensor((1, 1, 3072), dtype="float16") = lv404_1[1]
            rms_norm76: R.Tensor((1, 1, 3072), dtype="float16") = lv404_1[0]
            lv409 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_5_mlp_gate_up_proj_q_weight2, transformer_h_5_mlp_gate_up_proj_q_scale2, rms_norm76), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv116 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv409,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv410 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_5_mlp_down_proj_q_weight2, transformer_h_5_mlp_down_proj_q_scale2, lv116), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv406_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv410, lv405_1, transformer_h_6_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv407_1: R.Tensor((1, 1, 3072), dtype="float16") = lv406_1[1]
            rms_norm77: R.Tensor((1, 1, 3072), dtype="float16") = lv406_1[0]
            lv411 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_6_mixer_qkv_proj_q_weight2, transformer_h_6_mixer_qkv_proj_q_scale2, rms_norm77), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv117 = R.call_tir(cls.fused_reshape8_reshape9, (lv411,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv194 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), lv117), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv118 = R.call_tir(cls.fused_reshape10_reshape11, (lv194,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv412 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_6_mixer_out_proj_q_weight2, transformer_h_6_mixer_out_proj_q_scale2, lv118), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv408_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv412, lv407_1, transformer_h_6_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv409_1: R.Tensor((1, 1, 3072), dtype="float16") = lv408_1[1]
            rms_norm78: R.Tensor((1, 1, 3072), dtype="float16") = lv408_1[0]
            lv413 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_6_mlp_gate_up_proj_q_weight2, transformer_h_6_mlp_gate_up_proj_q_scale2, rms_norm78), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv119 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv413,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv414 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_6_mlp_down_proj_q_weight2, transformer_h_6_mlp_down_proj_q_scale2, lv119), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv410_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv414, lv409_1, transformer_h_7_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv411_1: R.Tensor((1, 1, 3072), dtype="float16") = lv410_1[1]
            rms_norm79: R.Tensor((1, 1, 3072), dtype="float16") = lv410_1[0]
            lv415 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_7_mixer_qkv_proj_q_weight2, transformer_h_7_mixer_qkv_proj_q_scale2, rms_norm79), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv120 = R.call_tir(cls.fused_reshape8_reshape9, (lv415,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv199 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), lv120), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv121 = R.call_tir(cls.fused_reshape10_reshape11, (lv199,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv416 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_7_mixer_out_proj_q_weight2, transformer_h_7_mixer_out_proj_q_scale2, lv121), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv412_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv416, lv411_1, transformer_h_7_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv413_1: R.Tensor((1, 1, 3072), dtype="float16") = lv412_1[1]
            rms_norm80: R.Tensor((1, 1, 3072), dtype="float16") = lv412_1[0]
            lv417 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_7_mlp_gate_up_proj_q_weight2, transformer_h_7_mlp_gate_up_proj_q_scale2, rms_norm80), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv122 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv417,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv418 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_7_mlp_down_proj_q_weight2, transformer_h_7_mlp_down_proj_q_scale2, lv122), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv414_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv418, lv413_1, transformer_h_8_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv415_1: R.Tensor((1, 1, 3072), dtype="float16") = lv414_1[1]
            rms_norm81: R.Tensor((1, 1, 3072), dtype="float16") = lv414_1[0]
            lv419 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_8_mixer_qkv_proj_q_weight2, transformer_h_8_mixer_qkv_proj_q_scale2, rms_norm81), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv123 = R.call_tir(cls.fused_reshape8_reshape9, (lv419,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv204 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), lv123), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv124 = R.call_tir(cls.fused_reshape10_reshape11, (lv204,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv420 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_8_mixer_out_proj_q_weight2, transformer_h_8_mixer_out_proj_q_scale2, lv124), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv416_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv420, lv415_1, transformer_h_8_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv417_1: R.Tensor((1, 1, 3072), dtype="float16") = lv416_1[1]
            rms_norm82: R.Tensor((1, 1, 3072), dtype="float16") = lv416_1[0]
            lv421 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_8_mlp_gate_up_proj_q_weight2, transformer_h_8_mlp_gate_up_proj_q_scale2, rms_norm82), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv125 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv421,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv422 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_8_mlp_down_proj_q_weight2, transformer_h_8_mlp_down_proj_q_scale2, lv125), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv418_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv422, lv417_1, transformer_h_9_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv419_1: R.Tensor((1, 1, 3072), dtype="float16") = lv418_1[1]
            rms_norm83: R.Tensor((1, 1, 3072), dtype="float16") = lv418_1[0]
            lv423 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_9_mixer_qkv_proj_q_weight2, transformer_h_9_mixer_qkv_proj_q_scale2, rms_norm83), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv126 = R.call_tir(cls.fused_reshape8_reshape9, (lv423,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv209 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), lv126), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv127 = R.call_tir(cls.fused_reshape10_reshape11, (lv209,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv424 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_9_mixer_out_proj_q_weight2, transformer_h_9_mixer_out_proj_q_scale2, lv127), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv420_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv424, lv419_1, transformer_h_9_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv421_1: R.Tensor((1, 1, 3072), dtype="float16") = lv420_1[1]
            rms_norm84: R.Tensor((1, 1, 3072), dtype="float16") = lv420_1[0]
            lv425 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_9_mlp_gate_up_proj_q_weight2, transformer_h_9_mlp_gate_up_proj_q_scale2, rms_norm84), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv128 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv425,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv426 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_9_mlp_down_proj_q_weight2, transformer_h_9_mlp_down_proj_q_scale2, lv128), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv422_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv426, lv421_1, transformer_h_10_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv423_1: R.Tensor((1, 1, 3072), dtype="float16") = lv422_1[1]
            rms_norm85: R.Tensor((1, 1, 3072), dtype="float16") = lv422_1[0]
            lv427 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_10_mixer_qkv_proj_q_weight2, transformer_h_10_mixer_qkv_proj_q_scale2, rms_norm85), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv129 = R.call_tir(cls.fused_reshape8_reshape9, (lv427,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv214 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), lv129), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv130 = R.call_tir(cls.fused_reshape10_reshape11, (lv214,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv428 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_10_mixer_out_proj_q_weight2, transformer_h_10_mixer_out_proj_q_scale2, lv130), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv424_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv428, lv423_1, transformer_h_10_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv425_1: R.Tensor((1, 1, 3072), dtype="float16") = lv424_1[1]
            rms_norm86: R.Tensor((1, 1, 3072), dtype="float16") = lv424_1[0]
            lv429 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_10_mlp_gate_up_proj_q_weight2, transformer_h_10_mlp_gate_up_proj_q_scale2, rms_norm86), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv131 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv429,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv430 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_10_mlp_down_proj_q_weight2, transformer_h_10_mlp_down_proj_q_scale2, lv131), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv426_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv430, lv425_1, transformer_h_11_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv427_1: R.Tensor((1, 1, 3072), dtype="float16") = lv426_1[1]
            rms_norm87: R.Tensor((1, 1, 3072), dtype="float16") = lv426_1[0]
            lv431 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_11_mixer_qkv_proj_q_weight2, transformer_h_11_mixer_qkv_proj_q_scale2, rms_norm87), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv132 = R.call_tir(cls.fused_reshape8_reshape9, (lv431,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv219 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), lv132), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv133 = R.call_tir(cls.fused_reshape10_reshape11, (lv219,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv432 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_11_mixer_out_proj_q_weight2, transformer_h_11_mixer_out_proj_q_scale2, lv133), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv428_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv432, lv427_1, transformer_h_11_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv429_1: R.Tensor((1, 1, 3072), dtype="float16") = lv428_1[1]
            rms_norm88: R.Tensor((1, 1, 3072), dtype="float16") = lv428_1[0]
            lv433 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_11_mlp_gate_up_proj_q_weight2, transformer_h_11_mlp_gate_up_proj_q_scale2, rms_norm88), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv134 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv433,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv434 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_11_mlp_down_proj_q_weight2, transformer_h_11_mlp_down_proj_q_scale2, lv134), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv430_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv434, lv429_1, transformer_h_12_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv431_1: R.Tensor((1, 1, 3072), dtype="float16") = lv430_1[1]
            rms_norm89: R.Tensor((1, 1, 3072), dtype="float16") = lv430_1[0]
            lv435 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_12_mixer_qkv_proj_q_weight2, transformer_h_12_mixer_qkv_proj_q_scale2, rms_norm89), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv135 = R.call_tir(cls.fused_reshape8_reshape9, (lv435,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv224 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), lv135), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv136 = R.call_tir(cls.fused_reshape10_reshape11, (lv224,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv436 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_12_mixer_out_proj_q_weight2, transformer_h_12_mixer_out_proj_q_scale2, lv136), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv432_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv436, lv431_1, transformer_h_12_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv433_1: R.Tensor((1, 1, 3072), dtype="float16") = lv432_1[1]
            rms_norm90: R.Tensor((1, 1, 3072), dtype="float16") = lv432_1[0]
            lv437 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_12_mlp_gate_up_proj_q_weight2, transformer_h_12_mlp_gate_up_proj_q_scale2, rms_norm90), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv137 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv437,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv438 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_12_mlp_down_proj_q_weight2, transformer_h_12_mlp_down_proj_q_scale2, lv137), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv434_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv438, lv433_1, transformer_h_13_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv435_1: R.Tensor((1, 1, 3072), dtype="float16") = lv434_1[1]
            rms_norm91: R.Tensor((1, 1, 3072), dtype="float16") = lv434_1[0]
            lv439 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_13_mixer_qkv_proj_q_weight2, transformer_h_13_mixer_qkv_proj_q_scale2, rms_norm91), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv138 = R.call_tir(cls.fused_reshape8_reshape9, (lv439,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv229 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), lv138), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv139 = R.call_tir(cls.fused_reshape10_reshape11, (lv229,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv440 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_13_mixer_out_proj_q_weight2, transformer_h_13_mixer_out_proj_q_scale2, lv139), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv436_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv440, lv435_1, transformer_h_13_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv437_1: R.Tensor((1, 1, 3072), dtype="float16") = lv436_1[1]
            rms_norm92: R.Tensor((1, 1, 3072), dtype="float16") = lv436_1[0]
            lv441 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_13_mlp_gate_up_proj_q_weight2, transformer_h_13_mlp_gate_up_proj_q_scale2, rms_norm92), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv140 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv441,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv442 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_13_mlp_down_proj_q_weight2, transformer_h_13_mlp_down_proj_q_scale2, lv140), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv438_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv442, lv437_1, transformer_h_14_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv439_1: R.Tensor((1, 1, 3072), dtype="float16") = lv438_1[1]
            rms_norm93: R.Tensor((1, 1, 3072), dtype="float16") = lv438_1[0]
            lv443 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_14_mixer_qkv_proj_q_weight2, transformer_h_14_mixer_qkv_proj_q_scale2, rms_norm93), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv141 = R.call_tir(cls.fused_reshape8_reshape9, (lv443,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv234 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), lv141), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv142 = R.call_tir(cls.fused_reshape10_reshape11, (lv234,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv444 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_14_mixer_out_proj_q_weight2, transformer_h_14_mixer_out_proj_q_scale2, lv142), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv440_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv444, lv439_1, transformer_h_14_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv441_1: R.Tensor((1, 1, 3072), dtype="float16") = lv440_1[1]
            rms_norm94: R.Tensor((1, 1, 3072), dtype="float16") = lv440_1[0]
            lv445 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_14_mlp_gate_up_proj_q_weight2, transformer_h_14_mlp_gate_up_proj_q_scale2, rms_norm94), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv143 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv445,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv446 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_14_mlp_down_proj_q_weight2, transformer_h_14_mlp_down_proj_q_scale2, lv143), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv442_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv446, lv441_1, transformer_h_15_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv443_1: R.Tensor((1, 1, 3072), dtype="float16") = lv442_1[1]
            rms_norm95: R.Tensor((1, 1, 3072), dtype="float16") = lv442_1[0]
            lv447 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_15_mixer_qkv_proj_q_weight2, transformer_h_15_mixer_qkv_proj_q_scale2, rms_norm95), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv144 = R.call_tir(cls.fused_reshape8_reshape9, (lv447,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv239 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), lv144), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv145 = R.call_tir(cls.fused_reshape10_reshape11, (lv239,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv448 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_15_mixer_out_proj_q_weight2, transformer_h_15_mixer_out_proj_q_scale2, lv145), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv444_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv448, lv443_1, transformer_h_15_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv445_1: R.Tensor((1, 1, 3072), dtype="float16") = lv444_1[1]
            rms_norm96: R.Tensor((1, 1, 3072), dtype="float16") = lv444_1[0]
            lv449 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_15_mlp_gate_up_proj_q_weight2, transformer_h_15_mlp_gate_up_proj_q_scale2, rms_norm96), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv146 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv449,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv450 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_15_mlp_down_proj_q_weight2, transformer_h_15_mlp_down_proj_q_scale2, lv146), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv446_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv450, lv445_1, transformer_h_16_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv447_1: R.Tensor((1, 1, 3072), dtype="float16") = lv446_1[1]
            rms_norm97: R.Tensor((1, 1, 3072), dtype="float16") = lv446_1[0]
            lv451 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_16_mixer_qkv_proj_q_weight2, transformer_h_16_mixer_qkv_proj_q_scale2, rms_norm97), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv147 = R.call_tir(cls.fused_reshape8_reshape9, (lv451,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv244 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), lv147), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv148 = R.call_tir(cls.fused_reshape10_reshape11, (lv244,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv452 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_16_mixer_out_proj_q_weight2, transformer_h_16_mixer_out_proj_q_scale2, lv148), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv448_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv452, lv447_1, transformer_h_16_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv449_1: R.Tensor((1, 1, 3072), dtype="float16") = lv448_1[1]
            rms_norm98: R.Tensor((1, 1, 3072), dtype="float16") = lv448_1[0]
            lv453 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_16_mlp_gate_up_proj_q_weight2, transformer_h_16_mlp_gate_up_proj_q_scale2, rms_norm98), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv149 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv453,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv454 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_16_mlp_down_proj_q_weight2, transformer_h_16_mlp_down_proj_q_scale2, lv149), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv450_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv454, lv449_1, transformer_h_17_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv451_1: R.Tensor((1, 1, 3072), dtype="float16") = lv450_1[1]
            rms_norm99: R.Tensor((1, 1, 3072), dtype="float16") = lv450_1[0]
            lv455 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_17_mixer_qkv_proj_q_weight2, transformer_h_17_mixer_qkv_proj_q_scale2, rms_norm99), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv150 = R.call_tir(cls.fused_reshape8_reshape9, (lv455,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv249 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), lv150), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv151 = R.call_tir(cls.fused_reshape10_reshape11, (lv249,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv456 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_17_mixer_out_proj_q_weight2, transformer_h_17_mixer_out_proj_q_scale2, lv151), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv452_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv456, lv451_1, transformer_h_17_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv453_1: R.Tensor((1, 1, 3072), dtype="float16") = lv452_1[1]
            rms_norm100: R.Tensor((1, 1, 3072), dtype="float16") = lv452_1[0]
            lv457 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_17_mlp_gate_up_proj_q_weight2, transformer_h_17_mlp_gate_up_proj_q_scale2, rms_norm100), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv152 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv457,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv458 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_17_mlp_down_proj_q_weight2, transformer_h_17_mlp_down_proj_q_scale2, lv152), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv454_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv458, lv453_1, transformer_h_18_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv455_1: R.Tensor((1, 1, 3072), dtype="float16") = lv454_1[1]
            rms_norm101: R.Tensor((1, 1, 3072), dtype="float16") = lv454_1[0]
            lv459 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_18_mixer_qkv_proj_q_weight2, transformer_h_18_mixer_qkv_proj_q_scale2, rms_norm101), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv153 = R.call_tir(cls.fused_reshape8_reshape9, (lv459,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv254 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), lv153), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv154 = R.call_tir(cls.fused_reshape10_reshape11, (lv254,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv460 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_18_mixer_out_proj_q_weight2, transformer_h_18_mixer_out_proj_q_scale2, lv154), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv456_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv460, lv455_1, transformer_h_18_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv457_1: R.Tensor((1, 1, 3072), dtype="float16") = lv456_1[1]
            rms_norm102: R.Tensor((1, 1, 3072), dtype="float16") = lv456_1[0]
            lv461 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_18_mlp_gate_up_proj_q_weight2, transformer_h_18_mlp_gate_up_proj_q_scale2, rms_norm102), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv155 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv461,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv462 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_18_mlp_down_proj_q_weight2, transformer_h_18_mlp_down_proj_q_scale2, lv155), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv458_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv462, lv457_1, transformer_h_19_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv459_1: R.Tensor((1, 1, 3072), dtype="float16") = lv458_1[1]
            rms_norm103: R.Tensor((1, 1, 3072), dtype="float16") = lv458_1[0]
            lv463 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_19_mixer_qkv_proj_q_weight2, transformer_h_19_mixer_qkv_proj_q_scale2, rms_norm103), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv156 = R.call_tir(cls.fused_reshape8_reshape9, (lv463,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv259 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), lv156), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv157 = R.call_tir(cls.fused_reshape10_reshape11, (lv259,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv464 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_19_mixer_out_proj_q_weight2, transformer_h_19_mixer_out_proj_q_scale2, lv157), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv460_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv464, lv459_1, transformer_h_19_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv461_1: R.Tensor((1, 1, 3072), dtype="float16") = lv460_1[1]
            rms_norm104: R.Tensor((1, 1, 3072), dtype="float16") = lv460_1[0]
            lv465 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_19_mlp_gate_up_proj_q_weight2, transformer_h_19_mlp_gate_up_proj_q_scale2, rms_norm104), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv158 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv465,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv466 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_19_mlp_down_proj_q_weight2, transformer_h_19_mlp_down_proj_q_scale2, lv158), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv462_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv466, lv461_1, transformer_h_20_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv463_1: R.Tensor((1, 1, 3072), dtype="float16") = lv462_1[1]
            rms_norm105: R.Tensor((1, 1, 3072), dtype="float16") = lv462_1[0]
            lv467 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_20_mixer_qkv_proj_q_weight2, transformer_h_20_mixer_qkv_proj_q_scale2, rms_norm105), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv159 = R.call_tir(cls.fused_reshape8_reshape9, (lv467,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv264 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), lv159), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv160 = R.call_tir(cls.fused_reshape10_reshape11, (lv264,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv468 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_20_mixer_out_proj_q_weight2, transformer_h_20_mixer_out_proj_q_scale2, lv160), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv464_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv468, lv463_1, transformer_h_20_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv465_1: R.Tensor((1, 1, 3072), dtype="float16") = lv464_1[1]
            rms_norm106: R.Tensor((1, 1, 3072), dtype="float16") = lv464_1[0]
            lv469 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_20_mlp_gate_up_proj_q_weight2, transformer_h_20_mlp_gate_up_proj_q_scale2, rms_norm106), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv161 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv469,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv470 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_20_mlp_down_proj_q_weight2, transformer_h_20_mlp_down_proj_q_scale2, lv161), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv466_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv470, lv465_1, transformer_h_21_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv467_1: R.Tensor((1, 1, 3072), dtype="float16") = lv466_1[1]
            rms_norm107: R.Tensor((1, 1, 3072), dtype="float16") = lv466_1[0]
            lv471 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_21_mixer_qkv_proj_q_weight2, transformer_h_21_mixer_qkv_proj_q_scale2, rms_norm107), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv162 = R.call_tir(cls.fused_reshape8_reshape9, (lv471,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv269 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), lv162), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv163 = R.call_tir(cls.fused_reshape10_reshape11, (lv269,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv472 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_21_mixer_out_proj_q_weight2, transformer_h_21_mixer_out_proj_q_scale2, lv163), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv468_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv472, lv467_1, transformer_h_21_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv469_1: R.Tensor((1, 1, 3072), dtype="float16") = lv468_1[1]
            rms_norm108: R.Tensor((1, 1, 3072), dtype="float16") = lv468_1[0]
            lv473 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_21_mlp_gate_up_proj_q_weight2, transformer_h_21_mlp_gate_up_proj_q_scale2, rms_norm108), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv164_1 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv473,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv474 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_21_mlp_down_proj_q_weight2, transformer_h_21_mlp_down_proj_q_scale2, lv164_1), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv470_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv474, lv469_1, transformer_h_22_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv471_1: R.Tensor((1, 1, 3072), dtype="float16") = lv470_1[1]
            rms_norm109: R.Tensor((1, 1, 3072), dtype="float16") = lv470_1[0]
            lv475 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_22_mixer_qkv_proj_q_weight2, transformer_h_22_mixer_qkv_proj_q_scale2, rms_norm109), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv165 = R.call_tir(cls.fused_reshape8_reshape9, (lv475,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv274 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), lv165), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv166 = R.call_tir(cls.fused_reshape10_reshape11, (lv274,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv476 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_22_mixer_out_proj_q_weight2, transformer_h_22_mixer_out_proj_q_scale2, lv166), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv472_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv476, lv471_1, transformer_h_22_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv473_1: R.Tensor((1, 1, 3072), dtype="float16") = lv472_1[1]
            rms_norm110: R.Tensor((1, 1, 3072), dtype="float16") = lv472_1[0]
            lv477 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_22_mlp_gate_up_proj_q_weight2, transformer_h_22_mlp_gate_up_proj_q_scale2, rms_norm110), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv167 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv477,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv478 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_22_mlp_down_proj_q_weight2, transformer_h_22_mlp_down_proj_q_scale2, lv167), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv474_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv478, lv473_1, transformer_h_23_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv475_1: R.Tensor((1, 1, 3072), dtype="float16") = lv474_1[1]
            rms_norm111: R.Tensor((1, 1, 3072), dtype="float16") = lv474_1[0]
            lv479 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_23_mixer_qkv_proj_q_weight2, transformer_h_23_mixer_qkv_proj_q_scale2, rms_norm111), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv168 = R.call_tir(cls.fused_reshape8_reshape9, (lv479,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv279 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), lv168), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv169_1 = R.call_tir(cls.fused_reshape10_reshape11, (lv279,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv480 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_23_mixer_out_proj_q_weight2, transformer_h_23_mixer_out_proj_q_scale2, lv169_1), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv476_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv480, lv475_1, transformer_h_23_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv477_1: R.Tensor((1, 1, 3072), dtype="float16") = lv476_1[1]
            rms_norm112: R.Tensor((1, 1, 3072), dtype="float16") = lv476_1[0]
            lv481 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_23_mlp_gate_up_proj_q_weight2, transformer_h_23_mlp_gate_up_proj_q_scale2, rms_norm112), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv170 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv481,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv482 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_23_mlp_down_proj_q_weight2, transformer_h_23_mlp_down_proj_q_scale2, lv170), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv478_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv482, lv477_1, transformer_h_24_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv479_1: R.Tensor((1, 1, 3072), dtype="float16") = lv478_1[1]
            rms_norm113: R.Tensor((1, 1, 3072), dtype="float16") = lv478_1[0]
            lv483 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_24_mixer_qkv_proj_q_weight2, transformer_h_24_mixer_qkv_proj_q_scale2, rms_norm113), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv171 = R.call_tir(cls.fused_reshape8_reshape9, (lv483,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv284 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), lv171), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv172 = R.call_tir(cls.fused_reshape10_reshape11, (lv284,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv484 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_24_mixer_out_proj_q_weight2, transformer_h_24_mixer_out_proj_q_scale2, lv172), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv480_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv484, lv479_1, transformer_h_24_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv481_1: R.Tensor((1, 1, 3072), dtype="float16") = lv480_1[1]
            rms_norm114: R.Tensor((1, 1, 3072), dtype="float16") = lv480_1[0]
            lv485 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_24_mlp_gate_up_proj_q_weight2, transformer_h_24_mlp_gate_up_proj_q_scale2, rms_norm114), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv173 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv485,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv486 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_24_mlp_down_proj_q_weight2, transformer_h_24_mlp_down_proj_q_scale2, lv173), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv482_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv486, lv481_1, transformer_h_25_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv483_1: R.Tensor((1, 1, 3072), dtype="float16") = lv482_1[1]
            rms_norm115: R.Tensor((1, 1, 3072), dtype="float16") = lv482_1[0]
            lv487 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_25_mixer_qkv_proj_q_weight2, transformer_h_25_mixer_qkv_proj_q_scale2, rms_norm115), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv174_1 = R.call_tir(cls.fused_reshape8_reshape9, (lv487,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv289 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), lv174_1), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv175 = R.call_tir(cls.fused_reshape10_reshape11, (lv289,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv488 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_25_mixer_out_proj_q_weight2, transformer_h_25_mixer_out_proj_q_scale2, lv175), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv484_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv488, lv483_1, transformer_h_25_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv485_1: R.Tensor((1, 1, 3072), dtype="float16") = lv484_1[1]
            rms_norm116: R.Tensor((1, 1, 3072), dtype="float16") = lv484_1[0]
            lv489 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_25_mlp_gate_up_proj_q_weight2, transformer_h_25_mlp_gate_up_proj_q_scale2, rms_norm116), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv176 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv489,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv490 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_25_mlp_down_proj_q_weight2, transformer_h_25_mlp_down_proj_q_scale2, lv176), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv486_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv490, lv485_1, transformer_h_26_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv487_1: R.Tensor((1, 1, 3072), dtype="float16") = lv486_1[1]
            rms_norm117: R.Tensor((1, 1, 3072), dtype="float16") = lv486_1[0]
            lv491 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_26_mixer_qkv_proj_q_weight2, transformer_h_26_mixer_qkv_proj_q_scale2, rms_norm117), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv177 = R.call_tir(cls.fused_reshape8_reshape9, (lv491,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv294 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), lv177), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv178 = R.call_tir(cls.fused_reshape10_reshape11, (lv294,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv492 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_26_mixer_out_proj_q_weight2, transformer_h_26_mixer_out_proj_q_scale2, lv178), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv488_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv492, lv487_1, transformer_h_26_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv489_1: R.Tensor((1, 1, 3072), dtype="float16") = lv488_1[1]
            rms_norm118: R.Tensor((1, 1, 3072), dtype="float16") = lv488_1[0]
            lv493 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_26_mlp_gate_up_proj_q_weight2, transformer_h_26_mlp_gate_up_proj_q_scale2, rms_norm118), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv179_1 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv493,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv494 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_26_mlp_down_proj_q_weight2, transformer_h_26_mlp_down_proj_q_scale2, lv179_1), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv490_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv494, lv489_1, transformer_h_27_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv491_1: R.Tensor((1, 1, 3072), dtype="float16") = lv490_1[1]
            rms_norm119: R.Tensor((1, 1, 3072), dtype="float16") = lv490_1[0]
            lv495 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_27_mixer_qkv_proj_q_weight2, transformer_h_27_mixer_qkv_proj_q_scale2, rms_norm119), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv180 = R.call_tir(cls.fused_reshape8_reshape9, (lv495,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv299 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), lv180), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv181 = R.call_tir(cls.fused_reshape10_reshape11, (lv299,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv496 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_27_mixer_out_proj_q_weight2, transformer_h_27_mixer_out_proj_q_scale2, lv181), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv492_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv496, lv491_1, transformer_h_27_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv493_1: R.Tensor((1, 1, 3072), dtype="float16") = lv492_1[1]
            rms_norm120: R.Tensor((1, 1, 3072), dtype="float16") = lv492_1[0]
            lv497 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_27_mlp_gate_up_proj_q_weight2, transformer_h_27_mlp_gate_up_proj_q_scale2, rms_norm120), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv182 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv497,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv498 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_27_mlp_down_proj_q_weight2, transformer_h_27_mlp_down_proj_q_scale2, lv182), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv494_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv498, lv493_1, transformer_h_28_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv495_1: R.Tensor((1, 1, 3072), dtype="float16") = lv494_1[1]
            rms_norm121: R.Tensor((1, 1, 3072), dtype="float16") = lv494_1[0]
            lv499 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_28_mixer_qkv_proj_q_weight2, transformer_h_28_mixer_qkv_proj_q_scale2, rms_norm121), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv183 = R.call_tir(cls.fused_reshape8_reshape9, (lv499,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv304 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), lv183), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv184_1 = R.call_tir(cls.fused_reshape10_reshape11, (lv304,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv500 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_28_mixer_out_proj_q_weight2, transformer_h_28_mixer_out_proj_q_scale2, lv184_1), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv496_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv500, lv495_1, transformer_h_28_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv497_1: R.Tensor((1, 1, 3072), dtype="float16") = lv496_1[1]
            rms_norm122: R.Tensor((1, 1, 3072), dtype="float16") = lv496_1[0]
            lv501 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_28_mlp_gate_up_proj_q_weight2, transformer_h_28_mlp_gate_up_proj_q_scale2, rms_norm122), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv185 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv501,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv502 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_28_mlp_down_proj_q_weight2, transformer_h_28_mlp_down_proj_q_scale2, lv185), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv498_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv502, lv497_1, transformer_h_29_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv499_1: R.Tensor((1, 1, 3072), dtype="float16") = lv498_1[1]
            rms_norm123: R.Tensor((1, 1, 3072), dtype="float16") = lv498_1[0]
            lv503 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_29_mixer_qkv_proj_q_weight2, transformer_h_29_mixer_qkv_proj_q_scale2, rms_norm123), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv186 = R.call_tir(cls.fused_reshape8_reshape9, (lv503,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv309 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), lv186), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv187 = R.call_tir(cls.fused_reshape10_reshape11, (lv309,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv504 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_29_mixer_out_proj_q_weight2, transformer_h_29_mixer_out_proj_q_scale2, lv187), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv500_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv504, lv499_1, transformer_h_29_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv501_1: R.Tensor((1, 1, 3072), dtype="float16") = lv500_1[1]
            rms_norm124: R.Tensor((1, 1, 3072), dtype="float16") = lv500_1[0]
            lv505 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_29_mlp_gate_up_proj_q_weight2, transformer_h_29_mlp_gate_up_proj_q_scale2, rms_norm124), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv188 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv505,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv506 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_29_mlp_down_proj_q_weight2, transformer_h_29_mlp_down_proj_q_scale2, lv188), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv502_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv506, lv501_1, transformer_h_30_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv503_1: R.Tensor((1, 1, 3072), dtype="float16") = lv502_1[1]
            rms_norm125: R.Tensor((1, 1, 3072), dtype="float16") = lv502_1[0]
            lv507 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_30_mixer_qkv_proj_q_weight2, transformer_h_30_mixer_qkv_proj_q_scale2, rms_norm125), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv189_1 = R.call_tir(cls.fused_reshape8_reshape9, (lv507,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv314 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), lv189_1), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv190 = R.call_tir(cls.fused_reshape10_reshape11, (lv314,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv508 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_30_mixer_out_proj_q_weight2, transformer_h_30_mixer_out_proj_q_scale2, lv190), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv504_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv508, lv503_1, transformer_h_30_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv505_1: R.Tensor((1, 1, 3072), dtype="float16") = lv504_1[1]
            rms_norm126: R.Tensor((1, 1, 3072), dtype="float16") = lv504_1[0]
            lv509 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_30_mlp_gate_up_proj_q_weight2, transformer_h_30_mlp_gate_up_proj_q_scale2, rms_norm126), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv191 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv509,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv510 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_30_mlp_down_proj_q_weight2, transformer_h_30_mlp_down_proj_q_scale2, lv191), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv506_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv510, lv505_1, transformer_h_31_ln_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv507_1: R.Tensor((1, 1, 3072), dtype="float16") = lv506_1[1]
            rms_norm127: R.Tensor((1, 1, 3072), dtype="float16") = lv506_1[0]
            lv511 = R.call_tir(cls.fused_dequantize1_NT_matmul10, (transformer_h_31_mixer_qkv_proj_q_weight2, transformer_h_31_mixer_qkv_proj_q_scale2, rms_norm127), out_sinfo=R.Tensor((1, 1, 9216), dtype="float16"))
            lv192 = R.call_tir(cls.fused_reshape8_reshape9, (lv511,), out_sinfo=R.Tensor((1, 96, 96), dtype="float16"))
            lv319 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), lv192), out_sinfo=R.Tensor((1, 32, 96), dtype="float16"))
            lv193 = R.call_tir(cls.fused_reshape10_reshape11, (lv319,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv512 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (transformer_h_31_mixer_out_proj_q_weight2, transformer_h_31_mixer_out_proj_q_scale2, lv193), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv508_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv512, lv507_1, transformer_h_31_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            lv509_1: R.Tensor((1, 1, 3072), dtype="float16") = lv508_1[1]
            rms_norm128: R.Tensor((1, 1, 3072), dtype="float16") = lv508_1[0]
            lv513 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (transformer_h_31_mlp_gate_up_proj_q_weight2, transformer_h_31_mlp_gate_up_proj_q_scale2, rms_norm128), out_sinfo=R.Tensor((1, 1, 16384), dtype="float16"))
            lv194_1 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv513,), out_sinfo=R.Tensor((1, 1, 8192), dtype="float16"))
            lv514 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (transformer_h_31_mlp_down_proj_q_weight2, transformer_h_31_mlp_down_proj_q_scale2, lv194_1), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv510_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv514, lv509_1, transformer_norm_weight2), out_sinfo=[R.Tensor((1, 1, 3072), dtype="float16"), R.Tensor((1, 1, 3072), dtype="float16")])
            rms_norm129: R.Tensor((1, 1, 3072), dtype="float16") = lv510_1[0]
            lv515 = R.call_tir(cls.fused_dequantize5_fused_NT_matmul14_cast2, (lm_head_q_weight2, lm_head_q_scale2, rms_norm129), out_sinfo=R.Tensor((1, 1, vocab_size), dtype="float32"))
            gv2: R.Tuple(R.Tensor((1, 1, vocab_size), dtype="float32"), R.Object) = lv515, paged_kv_cache
            R.output(gv2)
        return gv2

    @R.function
    def embed(input_ids: R.Tensor(("seq_len",), dtype="int32"), packed_params: R.Tuple(R.Tensor((32064, 384), dtype="uint32"), R.Tensor((32064, 96), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), 
R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 
256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), 
R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), 
dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), 
R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), 
dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor(("vocab_size", 384), dtype="uint32"), R.Tensor(("vocab_size", 96), dtype="float16"))) -> R.Tensor(("seq_len", 3072), dtype="float16"):
        # Token-embedding entry point. It maps the int32 `input_ids` of shape
        # (seq_len,) to a (seq_len, 3072) float16 tensor, per the return
        # annotation and the call_tir out_sinfo below.
        #
        # Only packed_params[0] (uint32, (32064, 384)) and packed_params[1]
        # (float16, (32064, 96)) are read here. The rest of the tuple is
        # declared but unused in this function. NOTE(review): presumably this
        # keeps one packed-weight signature shared with the other entry
        # points. Confirm before trimming it.
        #
        # Symbolic dimensions that appear in the annotations above.
        seq_len = T.int64()
        vocab_size = T.int64()  # only appears in the last two (lm_head-shaped) tuple entries
        # "num_input": 1 marks only the first parameter (`input_ids`) as a
        # runtime input (Relax `num_input` convention). "tir_var_upper_bound"
        # caps symbolic sizes, e.g. seq_len <= 2048, for memory planning.
        R.func_attr({"num_input": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        with R.dataflow():
            # Quantized embedding table: packed uint32 weights plus float16 scales.
            # The indices 0/1 depend on the exact tuple order above.
            transformer_embd_q_weight: R.Tensor((32064, 384), dtype="uint32") = packed_params[0]
            transformer_embd_q_scale: R.Tensor((32064, 96), dtype="float16") = packed_params[1]
            # fused_dequantize_take1 is a TIR kernel defined elsewhere in this
            # module. Judging by its name, it dequantizes the table and gathers
            # the rows indexed by input_ids. TODO confirm against the kernel body.
            gv = R.call_tir(cls.fused_dequantize_take1, (transformer_embd_q_weight, transformer_embd_q_scale, input_ids), out_sinfo=R.Tensor((seq_len, 3072), dtype="float16"))
            R.output(gv)
        return gv

    @R.function
    def multinomial_from_uniform(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), uniform_samples: R.Tensor(("num_samples",), dtype="float32"), sample_indices: R.Tensor(("num_samples",), dtype="int32")) -> R.Tensor(("num_samples",), dtype="int32"):
        # Produce one int32 result per entry of `uniform_samples`.
        #
        # The sampling kernel (parallel_sampling_from_prob, defined elsewhere
        # in this module) takes column-shaped (num_samples, 1) inputs and
        # returns a (num_samples, 1) int32 tensor. This wrapper therefore:
        #   1. views `uniform_samples` and `sample_indices` as columns,
        #   2. runs the kernel against `probs`,
        #   3. flattens the kernel output back to (num_samples,).
        # NOTE(review): the names suggest sample_indices[i] picks the `probs`
        # row for the i-th uniform draw. Confirm against the kernel.
        num_samples = T.int64(is_size_var=True)
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Step 1: reshape both 1-D inputs into (num_samples, 1) columns.
            draws_col: R.Tensor((num_samples, 1), dtype="float32") = R.call_pure_packed(
                "vm.builtin.reshape",
                uniform_samples,
                R.shape([num_samples, 1]),
                sinfo_args=(R.Tensor((num_samples, 1), dtype="float32"),),
            )
            row_ids_col: R.Tensor((num_samples, 1), dtype="int32") = R.call_pure_packed(
                "vm.builtin.reshape",
                sample_indices,
                R.shape([num_samples, 1]),
                sinfo_args=(R.Tensor((num_samples, 1), dtype="int32"),),
            )
            # Step 2: run the sampling kernel.
            picked_col = R.call_tir(
                cls.parallel_sampling_from_prob,
                (probs, draws_col, row_ids_col),
                out_sinfo=R.Tensor((num_samples, 1), dtype="int32"),
            )
            # Step 3: drop the trailing unit axis.
            picked: R.Tensor((num_samples,), dtype="int32") = R.call_pure_packed(
                "vm.builtin.reshape",
                picked_col,
                R.shape([num_samples]),
                sinfo_args=(R.Tensor((num_samples,), dtype="int32"),),
            )
            R.output(picked)
        return picked

    @R.function
    def prefill(input_embed: R.Tensor((1, "seq_len", 3072), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((32064, 384), dtype="uint32"), R.Tensor((32064, 96), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), 
R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 
1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), 
dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), 
R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), 
dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), 
R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((9216, 384), dtype="uint32"), R.Tensor((9216, 96), dtype="float16"), R.Tensor((3072, 384), dtype="uint32"), R.Tensor((3072, 96), dtype="float16"), R.Tensor((16384, 384), dtype="uint32"), R.Tensor((16384, 96), dtype="float16"), R.Tensor((3072, 1024), dtype="uint32"), R.Tensor((3072, 256), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor((3072,), dtype="float16"), R.Tensor(("vocab_size", 384), dtype="uint32"), R.Tensor(("vocab_size", 96), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, "vocab_size"), dtype="float32"), R.Object):
        vocab_size = T.int64()
        seq_len = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        with R.dataflow():
            transformer_h_0_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[2]
            transformer_h_0_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[3]
            transformer_h_0_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[4]
            transformer_h_0_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[5]
            transformer_h_0_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[6]
            transformer_h_0_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[7]
            transformer_h_0_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[8]
            transformer_h_0_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[9]
            transformer_h_0_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[10]
            transformer_h_0_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[11]
            transformer_h_1_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[12]
            transformer_h_1_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[13]
            transformer_h_1_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[14]
            transformer_h_1_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[15]
            transformer_h_1_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[16]
            transformer_h_1_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[17]
            transformer_h_1_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[18]
            transformer_h_1_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[19]
            transformer_h_1_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[20]
            transformer_h_1_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[21]
            transformer_h_2_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[22]
            transformer_h_2_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[23]
            transformer_h_2_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[24]
            transformer_h_2_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[25]
            transformer_h_2_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[26]
            transformer_h_2_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[27]
            transformer_h_2_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[28]
            transformer_h_2_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[29]
            transformer_h_2_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[30]
            transformer_h_2_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[31]
            transformer_h_3_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[32]
            transformer_h_3_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[33]
            transformer_h_3_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[34]
            transformer_h_3_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[35]
            transformer_h_3_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[36]
            transformer_h_3_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[37]
            transformer_h_3_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[38]
            transformer_h_3_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[39]
            transformer_h_3_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[40]
            transformer_h_3_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[41]
            transformer_h_4_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[42]
            transformer_h_4_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[43]
            transformer_h_4_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[44]
            transformer_h_4_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[45]
            transformer_h_4_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[46]
            transformer_h_4_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[47]
            transformer_h_4_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[48]
            transformer_h_4_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[49]
            transformer_h_4_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[50]
            transformer_h_4_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[51]
            transformer_h_5_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[52]
            transformer_h_5_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[53]
            transformer_h_5_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[54]
            transformer_h_5_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[55]
            transformer_h_5_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[56]
            transformer_h_5_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[57]
            transformer_h_5_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[58]
            transformer_h_5_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[59]
            transformer_h_5_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[60]
            transformer_h_5_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[61]
            transformer_h_6_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[62]
            transformer_h_6_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[63]
            transformer_h_6_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[64]
            transformer_h_6_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[65]
            transformer_h_6_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[66]
            transformer_h_6_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[67]
            transformer_h_6_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[68]
            transformer_h_6_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[69]
            transformer_h_6_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[70]
            transformer_h_6_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[71]
            transformer_h_7_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[72]
            transformer_h_7_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[73]
            transformer_h_7_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[74]
            transformer_h_7_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[75]
            transformer_h_7_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[76]
            transformer_h_7_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[77]
            transformer_h_7_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[78]
            transformer_h_7_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[79]
            transformer_h_7_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[80]
            transformer_h_7_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[81]
            transformer_h_8_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[82]
            transformer_h_8_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[83]
            transformer_h_8_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[84]
            transformer_h_8_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[85]
            transformer_h_8_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[86]
            transformer_h_8_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[87]
            transformer_h_8_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[88]
            transformer_h_8_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[89]
            transformer_h_8_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[90]
            transformer_h_8_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[91]
            transformer_h_9_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[92]
            transformer_h_9_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[93]
            transformer_h_9_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[94]
            transformer_h_9_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[95]
            transformer_h_9_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[96]
            transformer_h_9_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[97]
            transformer_h_9_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[98]
            transformer_h_9_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[99]
            transformer_h_9_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[100]
            transformer_h_9_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[101]
            transformer_h_10_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[102]
            transformer_h_10_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[103]
            transformer_h_10_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[104]
            transformer_h_10_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[105]
            transformer_h_10_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[106]
            transformer_h_10_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[107]
            transformer_h_10_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[108]
            transformer_h_10_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[109]
            transformer_h_10_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[110]
            transformer_h_10_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[111]
            transformer_h_11_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[112]
            transformer_h_11_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[113]
            transformer_h_11_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[114]
            transformer_h_11_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[115]
            transformer_h_11_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[116]
            transformer_h_11_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[117]
            transformer_h_11_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[118]
            transformer_h_11_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[119]
            transformer_h_11_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[120]
            transformer_h_11_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[121]
            transformer_h_12_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[122]
            transformer_h_12_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[123]
            transformer_h_12_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[124]
            transformer_h_12_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[125]
            transformer_h_12_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[126]
            transformer_h_12_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[127]
            transformer_h_12_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[128]
            transformer_h_12_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[129]
            transformer_h_12_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[130]
            transformer_h_12_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[131]
            transformer_h_13_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[132]
            transformer_h_13_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[133]
            transformer_h_13_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[134]
            transformer_h_13_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[135]
            transformer_h_13_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[136]
            transformer_h_13_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[137]
            transformer_h_13_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[138]
            transformer_h_13_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[139]
            transformer_h_13_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[140]
            transformer_h_13_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[141]
            transformer_h_14_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[142]
            transformer_h_14_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[143]
            transformer_h_14_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[144]
            transformer_h_14_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[145]
            transformer_h_14_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[146]
            transformer_h_14_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[147]
            transformer_h_14_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[148]
            transformer_h_14_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[149]
            transformer_h_14_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[150]
            transformer_h_14_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[151]
            transformer_h_15_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[152]
            transformer_h_15_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[153]
            transformer_h_15_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[154]
            transformer_h_15_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[155]
            transformer_h_15_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[156]
            transformer_h_15_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[157]
            transformer_h_15_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[158]
            transformer_h_15_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[159]
            transformer_h_15_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[160]
            transformer_h_15_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[161]
            transformer_h_16_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[162]
            transformer_h_16_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[163]
            transformer_h_16_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[164]
            transformer_h_16_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[165]
            transformer_h_16_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[166]
            transformer_h_16_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[167]
            transformer_h_16_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[168]
            transformer_h_16_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[169]
            transformer_h_16_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[170]
            transformer_h_16_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[171]
            transformer_h_17_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[172]
            transformer_h_17_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[173]
            transformer_h_17_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[174]
            transformer_h_17_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[175]
            transformer_h_17_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[176]
            transformer_h_17_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[177]
            transformer_h_17_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[178]
            transformer_h_17_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[179]
            transformer_h_17_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[180]
            transformer_h_17_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[181]
            transformer_h_18_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[182]
            transformer_h_18_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[183]
            transformer_h_18_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[184]
            transformer_h_18_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[185]
            transformer_h_18_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[186]
            transformer_h_18_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[187]
            transformer_h_18_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[188]
            transformer_h_18_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[189]
            transformer_h_18_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[190]
            transformer_h_18_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[191]
            transformer_h_19_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[192]
            transformer_h_19_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[193]
            transformer_h_19_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[194]
            transformer_h_19_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[195]
            transformer_h_19_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[196]
            transformer_h_19_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[197]
            transformer_h_19_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[198]
            transformer_h_19_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[199]
            transformer_h_19_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[200]
            transformer_h_19_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[201]
            transformer_h_20_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[202]
            transformer_h_20_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[203]
            transformer_h_20_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[204]
            transformer_h_20_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[205]
            transformer_h_20_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[206]
            transformer_h_20_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[207]
            transformer_h_20_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[208]
            transformer_h_20_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[209]
            transformer_h_20_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[210]
            transformer_h_20_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[211]
            transformer_h_21_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[212]
            transformer_h_21_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[213]
            transformer_h_21_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[214]
            transformer_h_21_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[215]
            transformer_h_21_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[216]
            transformer_h_21_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[217]
            transformer_h_21_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[218]
            transformer_h_21_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[219]
            transformer_h_21_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[220]
            transformer_h_21_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[221]
            transformer_h_22_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[222]
            transformer_h_22_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[223]
            transformer_h_22_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[224]
            transformer_h_22_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[225]
            transformer_h_22_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[226]
            transformer_h_22_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[227]
            transformer_h_22_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[228]
            transformer_h_22_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[229]
            transformer_h_22_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[230]
            transformer_h_22_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[231]
            transformer_h_23_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[232]
            transformer_h_23_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[233]
            transformer_h_23_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[234]
            transformer_h_23_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[235]
            transformer_h_23_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[236]
            transformer_h_23_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[237]
            transformer_h_23_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[238]
            transformer_h_23_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[239]
            transformer_h_23_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[240]
            transformer_h_23_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[241]
            transformer_h_24_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[242]
            transformer_h_24_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[243]
            transformer_h_24_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[244]
            transformer_h_24_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[245]
            transformer_h_24_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[246]
            transformer_h_24_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[247]
            transformer_h_24_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[248]
            transformer_h_24_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[249]
            transformer_h_24_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[250]
            transformer_h_24_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[251]
            transformer_h_25_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[252]
            transformer_h_25_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[253]
            transformer_h_25_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[254]
            transformer_h_25_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[255]
            transformer_h_25_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[256]
            transformer_h_25_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[257]
            transformer_h_25_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[258]
            transformer_h_25_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[259]
            transformer_h_25_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[260]
            transformer_h_25_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[261]
            transformer_h_26_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[262]
            transformer_h_26_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[263]
            transformer_h_26_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[264]
            transformer_h_26_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[265]
            transformer_h_26_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[266]
            transformer_h_26_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[267]
            transformer_h_26_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[268]
            transformer_h_26_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[269]
            transformer_h_26_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[270]
            transformer_h_26_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[271]
            transformer_h_27_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[272]
            transformer_h_27_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[273]
            transformer_h_27_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[274]
            transformer_h_27_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[275]
            transformer_h_27_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[276]
            transformer_h_27_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[277]
            transformer_h_27_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[278]
            transformer_h_27_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[279]
            transformer_h_27_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[280]
            transformer_h_27_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[281]
            transformer_h_28_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[282]
            transformer_h_28_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[283]
            transformer_h_28_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[284]
            transformer_h_28_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[285]
            transformer_h_28_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[286]
            transformer_h_28_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[287]
            transformer_h_28_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[288]
            transformer_h_28_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[289]
            transformer_h_28_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[290]
            transformer_h_28_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[291]
            transformer_h_29_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[292]
            transformer_h_29_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[293]
            transformer_h_29_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[294]
            transformer_h_29_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[295]
            transformer_h_29_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[296]
            transformer_h_29_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[297]
            transformer_h_29_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[298]
            transformer_h_29_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[299]
            transformer_h_29_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[300]
            transformer_h_29_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[301]
            transformer_h_30_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[302]
            transformer_h_30_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[303]
            transformer_h_30_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[304]
            transformer_h_30_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[305]
            transformer_h_30_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[306]
            transformer_h_30_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[307]
            transformer_h_30_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[308]
            transformer_h_30_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[309]
            transformer_h_30_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[310]
            transformer_h_30_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[311]
            transformer_h_31_ln_weight1: R.Tensor((3072,), dtype="float16") = packed_params[312]
            transformer_h_31_mixer_qkv_proj_q_weight1: R.Tensor((9216, 384), dtype="uint32") = packed_params[313]
            transformer_h_31_mixer_qkv_proj_q_scale1: R.Tensor((9216, 96), dtype="float16") = packed_params[314]
            transformer_h_31_mixer_out_proj_q_weight1: R.Tensor((3072, 384), dtype="uint32") = packed_params[315]
            transformer_h_31_mixer_out_proj_q_scale1: R.Tensor((3072, 96), dtype="float16") = packed_params[316]
            transformer_h_31_mlp_gate_up_proj_q_weight1: R.Tensor((16384, 384), dtype="uint32") = packed_params[317]
            transformer_h_31_mlp_gate_up_proj_q_scale1: R.Tensor((16384, 96), dtype="float16") = packed_params[318]
            transformer_h_31_mlp_down_proj_q_weight1: R.Tensor((3072, 1024), dtype="uint32") = packed_params[319]
            transformer_h_31_mlp_down_proj_q_scale1: R.Tensor((3072, 256), dtype="float16") = packed_params[320]
            transformer_h_31_post_attention_layernorm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[321]
            transformer_norm_weight1: R.Tensor((3072,), dtype="float16") = packed_params[322]
            lm_head_q_weight1: R.Tensor((vocab_size, 384), dtype="uint32") = packed_params[323]
            lm_head_q_scale1: R.Tensor((vocab_size, 96), dtype="float16") = packed_params[324]
            rms_norm = R.call_tir(cls.rms_norm1, (input_embed, transformer_h_0_ln_weight1), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv516 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_0_mixer_qkv_proj_q_weight1, transformer_h_0_mixer_qkv_proj_q_scale1, rms_norm), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape = R.call_tir(cls.reshape4, (lv516,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape1 = R.call_tir(cls.reshape5, (reshape,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv2 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape1), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape2 = R.call_tir(cls.reshape6, (lv2,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape3 = R.call_tir(cls.reshape7, (reshape2,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv517 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_0_mixer_out_proj_q_weight1, transformer_h_0_mixer_out_proj_q_scale1, reshape3), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv512 = R.call_tir(cls.fuse_add_norm_prefill, (lv517, input_embed, transformer_h_0_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv513: R.Tensor((1, seq_len, 3072), dtype="float16") = lv512[1]
            rms_norm1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv512[0]
            lv518 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_0_mlp_gate_up_proj_q_weight1, transformer_h_0_mlp_gate_up_proj_q_scale1, rms_norm1), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv196 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv518,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv519 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_0_mlp_down_proj_q_weight1, transformer_h_0_mlp_down_proj_q_scale1, lv196), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv514 = R.call_tir(cls.fuse_add_norm_prefill, (lv519, lv513, transformer_h_1_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv515: R.Tensor((1, seq_len, 3072), dtype="float16") = lv514[1]
            rms_norm2: R.Tensor((1, seq_len, 3072), dtype="float16") = lv514[0]
            lv520 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_1_mixer_qkv_proj_q_weight1, transformer_h_1_mixer_qkv_proj_q_scale1, rms_norm2), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape4 = R.call_tir(cls.reshape4, (lv520,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape5 = R.call_tir(cls.reshape5, (reshape4,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv7 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape5), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape6 = R.call_tir(cls.reshape6, (lv7,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape7 = R.call_tir(cls.reshape7, (reshape6,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv521 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_1_mixer_out_proj_q_weight1, transformer_h_1_mixer_out_proj_q_scale1, reshape7), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv516_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv521, lv515, transformer_h_1_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv517_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv516_1[1]
            rms_norm3: R.Tensor((1, seq_len, 3072), dtype="float16") = lv516_1[0]
            lv522 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_1_mlp_gate_up_proj_q_weight1, transformer_h_1_mlp_gate_up_proj_q_scale1, rms_norm3), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv197 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv522,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv523 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_1_mlp_down_proj_q_weight1, transformer_h_1_mlp_down_proj_q_scale1, lv197), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv518_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv523, lv517_1, transformer_h_2_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv519_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv518_1[1]
            rms_norm4: R.Tensor((1, seq_len, 3072), dtype="float16") = lv518_1[0]
            lv524 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_2_mixer_qkv_proj_q_weight1, transformer_h_2_mixer_qkv_proj_q_scale1, rms_norm4), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape8 = R.call_tir(cls.reshape4, (lv524,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape9 = R.call_tir(cls.reshape5, (reshape8,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv12 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape9), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape10 = R.call_tir(cls.reshape6, (lv12,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape11 = R.call_tir(cls.reshape7, (reshape10,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv525 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_2_mixer_out_proj_q_weight1, transformer_h_2_mixer_out_proj_q_scale1, reshape11), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv520_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv525, lv519_1, transformer_h_2_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv521_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv520_1[1]
            rms_norm5: R.Tensor((1, seq_len, 3072), dtype="float16") = lv520_1[0]
            lv526 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_2_mlp_gate_up_proj_q_weight1, transformer_h_2_mlp_gate_up_proj_q_scale1, rms_norm5), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv198 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv526,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv527 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_2_mlp_down_proj_q_weight1, transformer_h_2_mlp_down_proj_q_scale1, lv198), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv522_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv527, lv521_1, transformer_h_3_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv523_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv522_1[1]
            rms_norm6: R.Tensor((1, seq_len, 3072), dtype="float16") = lv522_1[0]
            lv528 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_3_mixer_qkv_proj_q_weight1, transformer_h_3_mixer_qkv_proj_q_scale1, rms_norm6), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape12 = R.call_tir(cls.reshape4, (lv528,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape13 = R.call_tir(cls.reshape5, (reshape12,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv17 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape13), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape14 = R.call_tir(cls.reshape6, (lv17,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape15 = R.call_tir(cls.reshape7, (reshape14,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv529 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_3_mixer_out_proj_q_weight1, transformer_h_3_mixer_out_proj_q_scale1, reshape15), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv524_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv529, lv523_1, transformer_h_3_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv525_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv524_1[1]
            rms_norm7: R.Tensor((1, seq_len, 3072), dtype="float16") = lv524_1[0]
            lv530 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_3_mlp_gate_up_proj_q_weight1, transformer_h_3_mlp_gate_up_proj_q_scale1, rms_norm7), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv199 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv530,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv531 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_3_mlp_down_proj_q_weight1, transformer_h_3_mlp_down_proj_q_scale1, lv199), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv526_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv531, lv525_1, transformer_h_4_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv527_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv526_1[1]
            rms_norm8: R.Tensor((1, seq_len, 3072), dtype="float16") = lv526_1[0]
            lv532 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_4_mixer_qkv_proj_q_weight1, transformer_h_4_mixer_qkv_proj_q_scale1, rms_norm8), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape16 = R.call_tir(cls.reshape4, (lv532,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape17 = R.call_tir(cls.reshape5, (reshape16,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv22 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape17), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape18 = R.call_tir(cls.reshape6, (lv22,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape19 = R.call_tir(cls.reshape7, (reshape18,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv533 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_4_mixer_out_proj_q_weight1, transformer_h_4_mixer_out_proj_q_scale1, reshape19), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv528_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv533, lv527_1, transformer_h_4_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv529_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv528_1[1]
            rms_norm9: R.Tensor((1, seq_len, 3072), dtype="float16") = lv528_1[0]
            lv534 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_4_mlp_gate_up_proj_q_weight1, transformer_h_4_mlp_gate_up_proj_q_scale1, rms_norm9), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv200 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv534,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv535 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_4_mlp_down_proj_q_weight1, transformer_h_4_mlp_down_proj_q_scale1, lv200), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv530_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv535, lv529_1, transformer_h_5_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv531_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv530_1[1]
            rms_norm10: R.Tensor((1, seq_len, 3072), dtype="float16") = lv530_1[0]
            lv536 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_5_mixer_qkv_proj_q_weight1, transformer_h_5_mixer_qkv_proj_q_scale1, rms_norm10), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape20 = R.call_tir(cls.reshape4, (lv536,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape21 = R.call_tir(cls.reshape5, (reshape20,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv27 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape21), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape22 = R.call_tir(cls.reshape6, (lv27,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape23 = R.call_tir(cls.reshape7, (reshape22,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv537 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_5_mixer_out_proj_q_weight1, transformer_h_5_mixer_out_proj_q_scale1, reshape23), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv532_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv537, lv531_1, transformer_h_5_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv533_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv532_1[1]
            rms_norm11: R.Tensor((1, seq_len, 3072), dtype="float16") = lv532_1[0]
            lv538 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_5_mlp_gate_up_proj_q_weight1, transformer_h_5_mlp_gate_up_proj_q_scale1, rms_norm11), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv201 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv538,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv539 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_5_mlp_down_proj_q_weight1, transformer_h_5_mlp_down_proj_q_scale1, lv201), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv534_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv539, lv533_1, transformer_h_6_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv535_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv534_1[1]
            rms_norm12: R.Tensor((1, seq_len, 3072), dtype="float16") = lv534_1[0]
            lv540 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_6_mixer_qkv_proj_q_weight1, transformer_h_6_mixer_qkv_proj_q_scale1, rms_norm12), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape24 = R.call_tir(cls.reshape4, (lv540,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape25 = R.call_tir(cls.reshape5, (reshape24,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv32 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape25), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape26 = R.call_tir(cls.reshape6, (lv32,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape27 = R.call_tir(cls.reshape7, (reshape26,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv541 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_6_mixer_out_proj_q_weight1, transformer_h_6_mixer_out_proj_q_scale1, reshape27), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv536_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv541, lv535_1, transformer_h_6_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv537_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv536_1[1]
            rms_norm13: R.Tensor((1, seq_len, 3072), dtype="float16") = lv536_1[0]
            lv542 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_6_mlp_gate_up_proj_q_weight1, transformer_h_6_mlp_gate_up_proj_q_scale1, rms_norm13), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv202 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv542,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv543 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_6_mlp_down_proj_q_weight1, transformer_h_6_mlp_down_proj_q_scale1, lv202), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv538_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv543, lv537_1, transformer_h_7_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv539_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv538_1[1]
            rms_norm14: R.Tensor((1, seq_len, 3072), dtype="float16") = lv538_1[0]
            lv544 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_7_mixer_qkv_proj_q_weight1, transformer_h_7_mixer_qkv_proj_q_scale1, rms_norm14), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape28 = R.call_tir(cls.reshape4, (lv544,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape29 = R.call_tir(cls.reshape5, (reshape28,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv37 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape29), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape30 = R.call_tir(cls.reshape6, (lv37,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape31 = R.call_tir(cls.reshape7, (reshape30,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv545 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_7_mixer_out_proj_q_weight1, transformer_h_7_mixer_out_proj_q_scale1, reshape31), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv540_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv545, lv539_1, transformer_h_7_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv541_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv540_1[1]
            rms_norm15: R.Tensor((1, seq_len, 3072), dtype="float16") = lv540_1[0]
            lv546 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_7_mlp_gate_up_proj_q_weight1, transformer_h_7_mlp_gate_up_proj_q_scale1, rms_norm15), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv203 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv546,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv547 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_7_mlp_down_proj_q_weight1, transformer_h_7_mlp_down_proj_q_scale1, lv203), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv542_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv547, lv541_1, transformer_h_8_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv543_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv542_1[1]
            rms_norm16: R.Tensor((1, seq_len, 3072), dtype="float16") = lv542_1[0]
            lv548 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_8_mixer_qkv_proj_q_weight1, transformer_h_8_mixer_qkv_proj_q_scale1, rms_norm16), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape32 = R.call_tir(cls.reshape4, (lv548,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape33 = R.call_tir(cls.reshape5, (reshape32,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv42 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape33), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape34 = R.call_tir(cls.reshape6, (lv42,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape35 = R.call_tir(cls.reshape7, (reshape34,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv549 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_8_mixer_out_proj_q_weight1, transformer_h_8_mixer_out_proj_q_scale1, reshape35), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv544_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv549, lv543_1, transformer_h_8_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv545_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv544_1[1]
            rms_norm17: R.Tensor((1, seq_len, 3072), dtype="float16") = lv544_1[0]
            lv550 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_8_mlp_gate_up_proj_q_weight1, transformer_h_8_mlp_gate_up_proj_q_scale1, rms_norm17), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv204 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv550,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv551 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_8_mlp_down_proj_q_weight1, transformer_h_8_mlp_down_proj_q_scale1, lv204), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv546_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv551, lv545_1, transformer_h_9_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv547_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv546_1[1]
            rms_norm18: R.Tensor((1, seq_len, 3072), dtype="float16") = lv546_1[0]
            lv552 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_9_mixer_qkv_proj_q_weight1, transformer_h_9_mixer_qkv_proj_q_scale1, rms_norm18), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape36 = R.call_tir(cls.reshape4, (lv552,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape37 = R.call_tir(cls.reshape5, (reshape36,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv47 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape37), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape38 = R.call_tir(cls.reshape6, (lv47,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape39 = R.call_tir(cls.reshape7, (reshape38,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv553 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_9_mixer_out_proj_q_weight1, transformer_h_9_mixer_out_proj_q_scale1, reshape39), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv548_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv553, lv547_1, transformer_h_9_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv549_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv548_1[1]
            rms_norm19: R.Tensor((1, seq_len, 3072), dtype="float16") = lv548_1[0]
            lv554 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_9_mlp_gate_up_proj_q_weight1, transformer_h_9_mlp_gate_up_proj_q_scale1, rms_norm19), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv205 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv554,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv555 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_9_mlp_down_proj_q_weight1, transformer_h_9_mlp_down_proj_q_scale1, lv205), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv550_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv555, lv549_1, transformer_h_10_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv551_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv550_1[1]
            rms_norm20: R.Tensor((1, seq_len, 3072), dtype="float16") = lv550_1[0]
            lv556 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_10_mixer_qkv_proj_q_weight1, transformer_h_10_mixer_qkv_proj_q_scale1, rms_norm20), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape40 = R.call_tir(cls.reshape4, (lv556,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape41 = R.call_tir(cls.reshape5, (reshape40,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv52 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape41), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape42 = R.call_tir(cls.reshape6, (lv52,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape43 = R.call_tir(cls.reshape7, (reshape42,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv557 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_10_mixer_out_proj_q_weight1, transformer_h_10_mixer_out_proj_q_scale1, reshape43), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv552_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv557, lv551_1, transformer_h_10_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv553_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv552_1[1]
            rms_norm21: R.Tensor((1, seq_len, 3072), dtype="float16") = lv552_1[0]
            lv558 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_10_mlp_gate_up_proj_q_weight1, transformer_h_10_mlp_gate_up_proj_q_scale1, rms_norm21), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv206 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv558,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv559 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_10_mlp_down_proj_q_weight1, transformer_h_10_mlp_down_proj_q_scale1, lv206), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv554_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv559, lv553_1, transformer_h_11_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv555_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv554_1[1]
            rms_norm22: R.Tensor((1, seq_len, 3072), dtype="float16") = lv554_1[0]
            lv560 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_11_mixer_qkv_proj_q_weight1, transformer_h_11_mixer_qkv_proj_q_scale1, rms_norm22), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape44 = R.call_tir(cls.reshape4, (lv560,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape45 = R.call_tir(cls.reshape5, (reshape44,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv57 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape45), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape46 = R.call_tir(cls.reshape6, (lv57,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape47 = R.call_tir(cls.reshape7, (reshape46,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv561 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_11_mixer_out_proj_q_weight1, transformer_h_11_mixer_out_proj_q_scale1, reshape47), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv556_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv561, lv555_1, transformer_h_11_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv557_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv556_1[1]
            rms_norm23: R.Tensor((1, seq_len, 3072), dtype="float16") = lv556_1[0]
            lv562 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_11_mlp_gate_up_proj_q_weight1, transformer_h_11_mlp_gate_up_proj_q_scale1, rms_norm23), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv207 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv562,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv563 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_11_mlp_down_proj_q_weight1, transformer_h_11_mlp_down_proj_q_scale1, lv207), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv558_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv563, lv557_1, transformer_h_12_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv559_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv558_1[1]
            rms_norm24: R.Tensor((1, seq_len, 3072), dtype="float16") = lv558_1[0]
            lv564 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_12_mixer_qkv_proj_q_weight1, transformer_h_12_mixer_qkv_proj_q_scale1, rms_norm24), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape48 = R.call_tir(cls.reshape4, (lv564,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape49 = R.call_tir(cls.reshape5, (reshape48,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv62 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape49), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape50 = R.call_tir(cls.reshape6, (lv62,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape51 = R.call_tir(cls.reshape7, (reshape50,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv565 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_12_mixer_out_proj_q_weight1, transformer_h_12_mixer_out_proj_q_scale1, reshape51), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv560_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv565, lv559_1, transformer_h_12_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv561_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv560_1[1]
            rms_norm25: R.Tensor((1, seq_len, 3072), dtype="float16") = lv560_1[0]
            lv566 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_12_mlp_gate_up_proj_q_weight1, transformer_h_12_mlp_gate_up_proj_q_scale1, rms_norm25), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv208 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv566,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv567 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_12_mlp_down_proj_q_weight1, transformer_h_12_mlp_down_proj_q_scale1, lv208), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv562_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv567, lv561_1, transformer_h_13_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv563_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv562_1[1]
            rms_norm26: R.Tensor((1, seq_len, 3072), dtype="float16") = lv562_1[0]
            lv568 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_13_mixer_qkv_proj_q_weight1, transformer_h_13_mixer_qkv_proj_q_scale1, rms_norm26), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape52 = R.call_tir(cls.reshape4, (lv568,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape53 = R.call_tir(cls.reshape5, (reshape52,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv67 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape53), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape54 = R.call_tir(cls.reshape6, (lv67,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape55 = R.call_tir(cls.reshape7, (reshape54,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv569 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_13_mixer_out_proj_q_weight1, transformer_h_13_mixer_out_proj_q_scale1, reshape55), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv564_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv569, lv563_1, transformer_h_13_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv565_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv564_1[1]
            rms_norm27: R.Tensor((1, seq_len, 3072), dtype="float16") = lv564_1[0]
            lv570 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_13_mlp_gate_up_proj_q_weight1, transformer_h_13_mlp_gate_up_proj_q_scale1, rms_norm27), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv209 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv570,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv571 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_13_mlp_down_proj_q_weight1, transformer_h_13_mlp_down_proj_q_scale1, lv209), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv566_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv571, lv565_1, transformer_h_14_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv567_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv566_1[1]
            rms_norm28: R.Tensor((1, seq_len, 3072), dtype="float16") = lv566_1[0]
            lv572 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_14_mixer_qkv_proj_q_weight1, transformer_h_14_mixer_qkv_proj_q_scale1, rms_norm28), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape56 = R.call_tir(cls.reshape4, (lv572,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape57 = R.call_tir(cls.reshape5, (reshape56,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv72 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape57), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape58 = R.call_tir(cls.reshape6, (lv72,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape59 = R.call_tir(cls.reshape7, (reshape58,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv573 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_14_mixer_out_proj_q_weight1, transformer_h_14_mixer_out_proj_q_scale1, reshape59), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv568_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv573, lv567_1, transformer_h_14_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv569_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv568_1[1]
            rms_norm29: R.Tensor((1, seq_len, 3072), dtype="float16") = lv568_1[0]
            lv574 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_14_mlp_gate_up_proj_q_weight1, transformer_h_14_mlp_gate_up_proj_q_scale1, rms_norm29), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv210 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv574,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv575 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_14_mlp_down_proj_q_weight1, transformer_h_14_mlp_down_proj_q_scale1, lv210), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv570_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv575, lv569_1, transformer_h_15_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv571_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv570_1[1]
            rms_norm30: R.Tensor((1, seq_len, 3072), dtype="float16") = lv570_1[0]
            lv576 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_15_mixer_qkv_proj_q_weight1, transformer_h_15_mixer_qkv_proj_q_scale1, rms_norm30), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape60 = R.call_tir(cls.reshape4, (lv576,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape61 = R.call_tir(cls.reshape5, (reshape60,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv77 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape61), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape62 = R.call_tir(cls.reshape6, (lv77,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape63 = R.call_tir(cls.reshape7, (reshape62,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv577 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_15_mixer_out_proj_q_weight1, transformer_h_15_mixer_out_proj_q_scale1, reshape63), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv572_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv577, lv571_1, transformer_h_15_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv573_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv572_1[1]
            rms_norm31: R.Tensor((1, seq_len, 3072), dtype="float16") = lv572_1[0]
            lv578 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_15_mlp_gate_up_proj_q_weight1, transformer_h_15_mlp_gate_up_proj_q_scale1, rms_norm31), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv211 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv578,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv579 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_15_mlp_down_proj_q_weight1, transformer_h_15_mlp_down_proj_q_scale1, lv211), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv574_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv579, lv573_1, transformer_h_16_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv575_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv574_1[1]
            rms_norm32: R.Tensor((1, seq_len, 3072), dtype="float16") = lv574_1[0]
            lv580 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_16_mixer_qkv_proj_q_weight1, transformer_h_16_mixer_qkv_proj_q_scale1, rms_norm32), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape64 = R.call_tir(cls.reshape4, (lv580,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape65 = R.call_tir(cls.reshape5, (reshape64,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv82 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape65), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape66 = R.call_tir(cls.reshape6, (lv82,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape67 = R.call_tir(cls.reshape7, (reshape66,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv581 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_16_mixer_out_proj_q_weight1, transformer_h_16_mixer_out_proj_q_scale1, reshape67), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv576_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv581, lv575_1, transformer_h_16_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv577_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv576_1[1]
            rms_norm33: R.Tensor((1, seq_len, 3072), dtype="float16") = lv576_1[0]
            lv582 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_16_mlp_gate_up_proj_q_weight1, transformer_h_16_mlp_gate_up_proj_q_scale1, rms_norm33), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv212 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv582,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv583 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_16_mlp_down_proj_q_weight1, transformer_h_16_mlp_down_proj_q_scale1, lv212), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv578_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv583, lv577_1, transformer_h_17_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv579_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv578_1[1]
            rms_norm34: R.Tensor((1, seq_len, 3072), dtype="float16") = lv578_1[0]
            lv584 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_17_mixer_qkv_proj_q_weight1, transformer_h_17_mixer_qkv_proj_q_scale1, rms_norm34), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape68 = R.call_tir(cls.reshape4, (lv584,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape69 = R.call_tir(cls.reshape5, (reshape68,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv87 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape69), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape70 = R.call_tir(cls.reshape6, (lv87,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape71 = R.call_tir(cls.reshape7, (reshape70,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv585 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_17_mixer_out_proj_q_weight1, transformer_h_17_mixer_out_proj_q_scale1, reshape71), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv580_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv585, lv579_1, transformer_h_17_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv581_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv580_1[1]
            rms_norm35: R.Tensor((1, seq_len, 3072), dtype="float16") = lv580_1[0]
            lv586 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_17_mlp_gate_up_proj_q_weight1, transformer_h_17_mlp_gate_up_proj_q_scale1, rms_norm35), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv213 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv586,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv587 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_17_mlp_down_proj_q_weight1, transformer_h_17_mlp_down_proj_q_scale1, lv213), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv582_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv587, lv581_1, transformer_h_18_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv583_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv582_1[1]
            rms_norm36: R.Tensor((1, seq_len, 3072), dtype="float16") = lv582_1[0]
            lv588 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_18_mixer_qkv_proj_q_weight1, transformer_h_18_mixer_qkv_proj_q_scale1, rms_norm36), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape72 = R.call_tir(cls.reshape4, (lv588,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape73 = R.call_tir(cls.reshape5, (reshape72,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv92 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape73), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape74 = R.call_tir(cls.reshape6, (lv92,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape75 = R.call_tir(cls.reshape7, (reshape74,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv589 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_18_mixer_out_proj_q_weight1, transformer_h_18_mixer_out_proj_q_scale1, reshape75), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv584_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv589, lv583_1, transformer_h_18_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv585_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv584_1[1]
            rms_norm37: R.Tensor((1, seq_len, 3072), dtype="float16") = lv584_1[0]
            lv590 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_18_mlp_gate_up_proj_q_weight1, transformer_h_18_mlp_gate_up_proj_q_scale1, rms_norm37), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv214 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv590,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv591 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_18_mlp_down_proj_q_weight1, transformer_h_18_mlp_down_proj_q_scale1, lv214), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv586_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv591, lv585_1, transformer_h_19_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv587_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv586_1[1]
            rms_norm38: R.Tensor((1, seq_len, 3072), dtype="float16") = lv586_1[0]
            lv592 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_19_mixer_qkv_proj_q_weight1, transformer_h_19_mixer_qkv_proj_q_scale1, rms_norm38), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape76 = R.call_tir(cls.reshape4, (lv592,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape77 = R.call_tir(cls.reshape5, (reshape76,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv97 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape77), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape78 = R.call_tir(cls.reshape6, (lv97,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape79 = R.call_tir(cls.reshape7, (reshape78,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv593 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_19_mixer_out_proj_q_weight1, transformer_h_19_mixer_out_proj_q_scale1, reshape79), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv588_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv593, lv587_1, transformer_h_19_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv589_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv588_1[1]
            rms_norm39: R.Tensor((1, seq_len, 3072), dtype="float16") = lv588_1[0]
            lv594 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_19_mlp_gate_up_proj_q_weight1, transformer_h_19_mlp_gate_up_proj_q_scale1, rms_norm39), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv215 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv594,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv595 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_19_mlp_down_proj_q_weight1, transformer_h_19_mlp_down_proj_q_scale1, lv215), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv590_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv595, lv589_1, transformer_h_20_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv591_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv590_1[1]
            rms_norm40: R.Tensor((1, seq_len, 3072), dtype="float16") = lv590_1[0]
            lv596 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_20_mixer_qkv_proj_q_weight1, transformer_h_20_mixer_qkv_proj_q_scale1, rms_norm40), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape80 = R.call_tir(cls.reshape4, (lv596,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape81 = R.call_tir(cls.reshape5, (reshape80,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv102 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape81), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape82 = R.call_tir(cls.reshape6, (lv102,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape83 = R.call_tir(cls.reshape7, (reshape82,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv597 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_20_mixer_out_proj_q_weight1, transformer_h_20_mixer_out_proj_q_scale1, reshape83), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv592_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv597, lv591_1, transformer_h_20_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv593_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv592_1[1]
            rms_norm41: R.Tensor((1, seq_len, 3072), dtype="float16") = lv592_1[0]
            lv598 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_20_mlp_gate_up_proj_q_weight1, transformer_h_20_mlp_gate_up_proj_q_scale1, rms_norm41), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv216 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv598,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv599 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_20_mlp_down_proj_q_weight1, transformer_h_20_mlp_down_proj_q_scale1, lv216), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv594_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv599, lv593_1, transformer_h_21_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv595_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv594_1[1]
            rms_norm42: R.Tensor((1, seq_len, 3072), dtype="float16") = lv594_1[0]
            lv600 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_21_mixer_qkv_proj_q_weight1, transformer_h_21_mixer_qkv_proj_q_scale1, rms_norm42), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape84 = R.call_tir(cls.reshape4, (lv600,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape85 = R.call_tir(cls.reshape5, (reshape84,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv107 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape85), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape86 = R.call_tir(cls.reshape6, (lv107,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape87 = R.call_tir(cls.reshape7, (reshape86,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv601 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_21_mixer_out_proj_q_weight1, transformer_h_21_mixer_out_proj_q_scale1, reshape87), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv596_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv601, lv595_1, transformer_h_21_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv597_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv596_1[1]
            rms_norm43: R.Tensor((1, seq_len, 3072), dtype="float16") = lv596_1[0]
            lv602 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_21_mlp_gate_up_proj_q_weight1, transformer_h_21_mlp_gate_up_proj_q_scale1, rms_norm43), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv217 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv602,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv603 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_21_mlp_down_proj_q_weight1, transformer_h_21_mlp_down_proj_q_scale1, lv217), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv598_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv603, lv597_1, transformer_h_22_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv599_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv598_1[1]
            rms_norm44: R.Tensor((1, seq_len, 3072), dtype="float16") = lv598_1[0]
            lv604 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_22_mixer_qkv_proj_q_weight1, transformer_h_22_mixer_qkv_proj_q_scale1, rms_norm44), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape88 = R.call_tir(cls.reshape4, (lv604,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape89 = R.call_tir(cls.reshape5, (reshape88,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv112 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape89), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape90 = R.call_tir(cls.reshape6, (lv112,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape91 = R.call_tir(cls.reshape7, (reshape90,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv605 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_22_mixer_out_proj_q_weight1, transformer_h_22_mixer_out_proj_q_scale1, reshape91), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv600_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv605, lv599_1, transformer_h_22_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv601_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv600_1[1]
            rms_norm45: R.Tensor((1, seq_len, 3072), dtype="float16") = lv600_1[0]
            lv606 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_22_mlp_gate_up_proj_q_weight1, transformer_h_22_mlp_gate_up_proj_q_scale1, rms_norm45), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv218 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv606,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv607 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_22_mlp_down_proj_q_weight1, transformer_h_22_mlp_down_proj_q_scale1, lv218), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv602_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv607, lv601_1, transformer_h_23_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv603_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv602_1[1]
            rms_norm46: R.Tensor((1, seq_len, 3072), dtype="float16") = lv602_1[0]
            lv608 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_23_mixer_qkv_proj_q_weight1, transformer_h_23_mixer_qkv_proj_q_scale1, rms_norm46), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape92 = R.call_tir(cls.reshape4, (lv608,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape93 = R.call_tir(cls.reshape5, (reshape92,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv117 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape93), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape94 = R.call_tir(cls.reshape6, (lv117,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape95 = R.call_tir(cls.reshape7, (reshape94,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv609 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_23_mixer_out_proj_q_weight1, transformer_h_23_mixer_out_proj_q_scale1, reshape95), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv604_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv609, lv603_1, transformer_h_23_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv605_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv604_1[1]
            rms_norm47: R.Tensor((1, seq_len, 3072), dtype="float16") = lv604_1[0]
            lv610 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_23_mlp_gate_up_proj_q_weight1, transformer_h_23_mlp_gate_up_proj_q_scale1, rms_norm47), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv219 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv610,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv611 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_23_mlp_down_proj_q_weight1, transformer_h_23_mlp_down_proj_q_scale1, lv219), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv606_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv611, lv605_1, transformer_h_24_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv607_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv606_1[1]
            rms_norm48: R.Tensor((1, seq_len, 3072), dtype="float16") = lv606_1[0]
            lv612 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_24_mixer_qkv_proj_q_weight1, transformer_h_24_mixer_qkv_proj_q_scale1, rms_norm48), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape96 = R.call_tir(cls.reshape4, (lv612,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape97 = R.call_tir(cls.reshape5, (reshape96,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv122 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape97), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape98 = R.call_tir(cls.reshape6, (lv122,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape99 = R.call_tir(cls.reshape7, (reshape98,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv613 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_24_mixer_out_proj_q_weight1, transformer_h_24_mixer_out_proj_q_scale1, reshape99), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv608_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv613, lv607_1, transformer_h_24_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv609_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv608_1[1]
            rms_norm49: R.Tensor((1, seq_len, 3072), dtype="float16") = lv608_1[0]
            lv614 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_24_mlp_gate_up_proj_q_weight1, transformer_h_24_mlp_gate_up_proj_q_scale1, rms_norm49), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv220 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv614,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv615 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_24_mlp_down_proj_q_weight1, transformer_h_24_mlp_down_proj_q_scale1, lv220), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv610_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv615, lv609_1, transformer_h_25_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv611_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv610_1[1]
            rms_norm50: R.Tensor((1, seq_len, 3072), dtype="float16") = lv610_1[0]
            lv616 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_25_mixer_qkv_proj_q_weight1, transformer_h_25_mixer_qkv_proj_q_scale1, rms_norm50), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape100 = R.call_tir(cls.reshape4, (lv616,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape101 = R.call_tir(cls.reshape5, (reshape100,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv127 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape101), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape102 = R.call_tir(cls.reshape6, (lv127,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape103 = R.call_tir(cls.reshape7, (reshape102,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv617 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_25_mixer_out_proj_q_weight1, transformer_h_25_mixer_out_proj_q_scale1, reshape103), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv612_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv617, lv611_1, transformer_h_25_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv613_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv612_1[1]
            rms_norm51: R.Tensor((1, seq_len, 3072), dtype="float16") = lv612_1[0]
            lv618 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_25_mlp_gate_up_proj_q_weight1, transformer_h_25_mlp_gate_up_proj_q_scale1, rms_norm51), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv221 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv618,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv619 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_25_mlp_down_proj_q_weight1, transformer_h_25_mlp_down_proj_q_scale1, lv221), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv614_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv619, lv613_1, transformer_h_26_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv615_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv614_1[1]
            rms_norm52: R.Tensor((1, seq_len, 3072), dtype="float16") = lv614_1[0]
            lv620 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_26_mixer_qkv_proj_q_weight1, transformer_h_26_mixer_qkv_proj_q_scale1, rms_norm52), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape104 = R.call_tir(cls.reshape4, (lv620,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape105 = R.call_tir(cls.reshape5, (reshape104,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv132 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape105), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape106 = R.call_tir(cls.reshape6, (lv132,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape107 = R.call_tir(cls.reshape7, (reshape106,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv621 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_26_mixer_out_proj_q_weight1, transformer_h_26_mixer_out_proj_q_scale1, reshape107), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv616_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv621, lv615_1, transformer_h_26_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv617_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv616_1[1]
            rms_norm53: R.Tensor((1, seq_len, 3072), dtype="float16") = lv616_1[0]
            lv622 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_26_mlp_gate_up_proj_q_weight1, transformer_h_26_mlp_gate_up_proj_q_scale1, rms_norm53), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv222 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv622,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv623 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_26_mlp_down_proj_q_weight1, transformer_h_26_mlp_down_proj_q_scale1, lv222), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv618_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv623, lv617_1, transformer_h_27_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv619_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv618_1[1]
            rms_norm54: R.Tensor((1, seq_len, 3072), dtype="float16") = lv618_1[0]
            lv624 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_27_mixer_qkv_proj_q_weight1, transformer_h_27_mixer_qkv_proj_q_scale1, rms_norm54), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape108 = R.call_tir(cls.reshape4, (lv624,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape109 = R.call_tir(cls.reshape5, (reshape108,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv137 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape109), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape110 = R.call_tir(cls.reshape6, (lv137,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape111 = R.call_tir(cls.reshape7, (reshape110,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv625 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_27_mixer_out_proj_q_weight1, transformer_h_27_mixer_out_proj_q_scale1, reshape111), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv620_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv625, lv619_1, transformer_h_27_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv621_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv620_1[1]
            rms_norm55: R.Tensor((1, seq_len, 3072), dtype="float16") = lv620_1[0]
            lv626 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_27_mlp_gate_up_proj_q_weight1, transformer_h_27_mlp_gate_up_proj_q_scale1, rms_norm55), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv223 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv626,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv627 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_27_mlp_down_proj_q_weight1, transformer_h_27_mlp_down_proj_q_scale1, lv223), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv622_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv627, lv621_1, transformer_h_28_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv623_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv622_1[1]
            rms_norm56: R.Tensor((1, seq_len, 3072), dtype="float16") = lv622_1[0]
            lv628 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_28_mixer_qkv_proj_q_weight1, transformer_h_28_mixer_qkv_proj_q_scale1, rms_norm56), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape112 = R.call_tir(cls.reshape4, (lv628,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape113 = R.call_tir(cls.reshape5, (reshape112,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv142 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape113), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape114 = R.call_tir(cls.reshape6, (lv142,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape115 = R.call_tir(cls.reshape7, (reshape114,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv629 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_28_mixer_out_proj_q_weight1, transformer_h_28_mixer_out_proj_q_scale1, reshape115), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv624_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv629, lv623_1, transformer_h_28_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv625_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv624_1[1]
            rms_norm57: R.Tensor((1, seq_len, 3072), dtype="float16") = lv624_1[0]
            lv630 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_28_mlp_gate_up_proj_q_weight1, transformer_h_28_mlp_gate_up_proj_q_scale1, rms_norm57), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv224 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv630,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv631 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_28_mlp_down_proj_q_weight1, transformer_h_28_mlp_down_proj_q_scale1, lv224), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv626_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv631, lv625_1, transformer_h_29_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv627_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv626_1[1]
            rms_norm58: R.Tensor((1, seq_len, 3072), dtype="float16") = lv626_1[0]
            lv632 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_29_mixer_qkv_proj_q_weight1, transformer_h_29_mixer_qkv_proj_q_scale1, rms_norm58), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape116 = R.call_tir(cls.reshape4, (lv632,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape117 = R.call_tir(cls.reshape5, (reshape116,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv147 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape117), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape118 = R.call_tir(cls.reshape6, (lv147,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape119 = R.call_tir(cls.reshape7, (reshape118,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv633 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_29_mixer_out_proj_q_weight1, transformer_h_29_mixer_out_proj_q_scale1, reshape119), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv628_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv633, lv627_1, transformer_h_29_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv629_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv628_1[1]
            rms_norm59: R.Tensor((1, seq_len, 3072), dtype="float16") = lv628_1[0]
            lv634 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_29_mlp_gate_up_proj_q_weight1, transformer_h_29_mlp_gate_up_proj_q_scale1, rms_norm59), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv225 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv634,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv635 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_29_mlp_down_proj_q_weight1, transformer_h_29_mlp_down_proj_q_scale1, lv225), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv630_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv635, lv629_1, transformer_h_30_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv631_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv630_1[1]
            rms_norm60: R.Tensor((1, seq_len, 3072), dtype="float16") = lv630_1[0]
            lv636 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_30_mixer_qkv_proj_q_weight1, transformer_h_30_mixer_qkv_proj_q_scale1, rms_norm60), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape120 = R.call_tir(cls.reshape4, (lv636,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape121 = R.call_tir(cls.reshape5, (reshape120,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv152 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape121), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape122 = R.call_tir(cls.reshape6, (lv152,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape123 = R.call_tir(cls.reshape7, (reshape122,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv637 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_30_mixer_out_proj_q_weight1, transformer_h_30_mixer_out_proj_q_scale1, reshape123), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv632_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv637, lv631_1, transformer_h_30_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv633_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv632_1[1]
            rms_norm61: R.Tensor((1, seq_len, 3072), dtype="float16") = lv632_1[0]
            lv638 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_30_mlp_gate_up_proj_q_weight1, transformer_h_30_mlp_gate_up_proj_q_scale1, rms_norm61), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv226 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv638,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv639 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_30_mlp_down_proj_q_weight1, transformer_h_30_mlp_down_proj_q_scale1, lv226), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv634_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv639, lv633_1, transformer_h_31_ln_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv635_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv634_1[1]
            rms_norm62: R.Tensor((1, seq_len, 3072), dtype="float16") = lv634_1[0]
            lv640 = R.call_tir(cls.fused_dequantize1_NT_matmul5, (transformer_h_31_mixer_qkv_proj_q_weight1, transformer_h_31_mixer_qkv_proj_q_scale1, rms_norm62), out_sinfo=R.Tensor((1, seq_len, 9216), dtype="float16"))
            reshape124 = R.call_tir(cls.reshape4, (lv640,), out_sinfo=R.Tensor((1, seq_len, 96, 96), dtype="float16"))
            reshape125 = R.call_tir(cls.reshape5, (reshape124,), out_sinfo=R.Tensor((seq_len, 96, 96), dtype="float16"))
            lv157 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape125), out_sinfo=R.Tensor((seq_len, 32, 96), dtype="float16"))
            reshape126 = R.call_tir(cls.reshape6, (lv157,), out_sinfo=R.Tensor((1, seq_len, 32, 96), dtype="float16"))
            reshape127 = R.call_tir(cls.reshape7, (reshape126,), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv641 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (transformer_h_31_mixer_out_proj_q_weight1, transformer_h_31_mixer_out_proj_q_scale1, reshape127), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv636_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv641, lv635_1, transformer_h_31_post_attention_layernorm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            lv637_1: R.Tensor((1, seq_len, 3072), dtype="float16") = lv636_1[1]
            rms_norm63: R.Tensor((1, seq_len, 3072), dtype="float16") = lv636_1[0]
            lv642 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (transformer_h_31_mlp_gate_up_proj_q_weight1, transformer_h_31_mlp_gate_up_proj_q_scale1, rms_norm63), out_sinfo=R.Tensor((1, seq_len, 16384), dtype="float16"))
            lv227 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv642,), out_sinfo=R.Tensor((1, seq_len, 8192), dtype="float16"))
            lv643 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (transformer_h_31_mlp_down_proj_q_weight1, transformer_h_31_mlp_down_proj_q_scale1, lv227), out_sinfo=R.Tensor((1, seq_len, 3072), dtype="float16"))
            lv638_1 = R.call_tir(cls.fuse_add_norm_prefill, (lv643, lv637_1, transformer_norm_weight1), out_sinfo=[R.Tensor((1, seq_len, 3072), dtype="float16"), R.Tensor((1, seq_len, 3072), dtype="float16")])
            rms_norm64: R.Tensor((1, seq_len, 3072), dtype="float16") = lv638_1[0]
            lv161 = R.call_tir(cls.index, (rms_norm64,), out_sinfo=R.Tensor((1, 1, 3072), dtype="float16"))
            lv644 = R.call_tir(cls.fused_dequantize5_fused_NT_matmul14_cast2, (lm_head_q_weight1, lm_head_q_scale1, lv161), out_sinfo=R.Tensor((1, 1, vocab_size), dtype="float32"))
            gv1: R.Tuple(R.Tensor((1, 1, vocab_size), dtype="float32"), R.Object) = lv644, paged_kv_cache
            R.output(gv1)
        return gv1

    @R.function
    def renormalize_by_top_p(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), top_p: R.Tensor(("batch_size",), dtype="float32"), init_pivots: R.Tensor(("batch_size", 3), dtype="float32")) -> R.Tensor(("batch_size", "vocab_size"), dtype="float32"):
        # Top-p renormalization of `probs`, performed as two TIR passes:
        #   1. `top_p_pivot_cutoff` reads probs, top_p and init_pivots and yields
        #      two float32 vectors of length batch_size.
        #   2. `top_p_renorm_after_cutoff` combines probs with those two vectors
        #      into a (batch_size, vocab_size) float32 result, which is returned.
        # NOTE(review): the meaning of the two cutoff vectors (presumably a pivot
        # threshold and a normalizer) is defined by the TIR kernels — confirm there.
        # Symbolic extents referenced by the parameter annotations.
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # Compilation hints, as the keys suggest: dynamic-output memory planning,
        # vocab_size non-negative, and upper bounds for the symbolic sizes.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Pass 1: per-row cutoff computation (returns a 2-tuple).
            pivot_cutoff = R.call_tir(
                cls.top_p_pivot_cutoff,
                (probs, top_p, init_pivots),
                out_sinfo=[R.Tensor((batch_size,), dtype="float32"), R.Tensor((batch_size,), dtype="float32")],
            )
            cutoff_out0: R.Tensor((batch_size,), dtype="float32") = pivot_cutoff[0]
            cutoff_out1: R.Tensor((batch_size,), dtype="float32") = pivot_cutoff[1]
            # Pass 2: produce the renormalized distribution from probs + cutoffs.
            renormed_probs = R.call_tir(
                cls.top_p_renorm_after_cutoff,
                (probs, cutoff_out0, cutoff_out1),
                out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"),
            )
            R.output(renormed_probs)
        return renormed_probs

    @R.function
    def sample_with_top_p(sorted_probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), sorted_indices: R.Tensor(("batch_size", "vocab_size"), dtype="int32"), uniform_samples: R.Tensor(("num_samples",), dtype="float32"), sample_indices: R.Tensor(("num_samples",), dtype="int32"), top_p: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("num_samples",), dtype="int32"):
        # Select one int32 token id per sample from pre-sorted per-row
        # probabilities, using a per-row top_p, one uniform random value per
        # sample and one row index per sample. Returns shape (num_samples,).
        # Symbolic extents referenced by the parameter annotations.
        num_samples = T.int64(is_size_var=True)
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # Compilation hints, as the keys suggest: dynamic-output memory planning,
        # vocab_size non-negative, and upper bounds for the symbolic sizes.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Reshape the 1-D inputs into (N, 1) column tensors for the kernels below.
            uniform_col: R.Tensor((num_samples, 1), dtype="float32") = R.call_pure_packed(
                "vm.builtin.reshape", uniform_samples, R.shape([num_samples, 1]),
                sinfo_args=(R.Tensor((num_samples, 1), dtype="float32"),),
            )
            sample_idx_col: R.Tensor((num_samples, 1), dtype="int32") = R.call_pure_packed(
                "vm.builtin.reshape", sample_indices, R.shape([num_samples, 1]),
                sinfo_args=(R.Tensor((num_samples, 1), dtype="int32"),),
            )
            # top_p as a (batch_size, 1) column.
            top_p_col: R.Tensor((batch_size, 1), dtype="float32") = R.call_pure_packed(
                "vm.builtin.reshape", top_p, R.shape([batch_size, 1]),
                sinfo_args=(R.Tensor((batch_size, 1), dtype="float32"),),
            )
            # (batch_size, 1) int32 tensor from `full`, parameterized by vocab_size.
            # NOTE(review): presumably every entry equals vocab_size — confirm in `full`.
            fill_vals = R.call_tir(
                cls.full,
                R.tuple(),
                out_sinfo=R.Tensor((batch_size, 1), dtype="int32"),
                tir_vars=R.shape([vocab_size]),
            )
            # Raw uint8 scratch buffer passed to `cumsum` as its second argument.
            # NOTE(review): likely the scan workspace for a library-backed cumsum;
            # the size formula is compiler-generated — do not edit by hand.
            scan_workspace: R.Tensor((8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12,), dtype="uint8") = R.builtin.alloc_tensor(
                R.shape([8 * (batch_size * vocab_size * 4) + 8388608 + batch_size * vocab_size * 12]),
                R.dtype("uint8"),
                R.prim_value(0),
                R.str("global"),
            )
            # Cumulative sums of sorted_probs (per the kernel name; axis presumably vocab).
            prefix_probs = R.call_tir(
                cls.cumsum,
                (sorted_probs, scan_workspace),
                out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"),
            )
            # One float32 value per row, derived from prefix sums, top_p and fill_vals.
            renorm_prob = R.call_tir(
                cls.get_renorm_prob,
                (prefix_probs, top_p_col, fill_vals),
                out_sinfo=R.Tensor((batch_size, 1), dtype="float32"),
            )
            # Map each sample's uniform value back to a token id via sorted_indices.
            sampled_col = R.call_tir(
                cls.get_index_from_sorted,
                (prefix_probs, sorted_indices, renorm_prob, uniform_col, sample_idx_col),
                out_sinfo=R.Tensor((num_samples, 1), dtype="int32"),
            )
            # Flatten (num_samples, 1) -> (num_samples,).
            sampled_tokens: R.Tensor((num_samples,), dtype="int32") = R.call_pure_packed(
                "vm.builtin.reshape", sampled_col, R.shape([num_samples]),
                sinfo_args=(R.Tensor((num_samples,), dtype="int32"),),
            )
            R.output(sampled_tokens)
        return sampled_tokens

    @R.function
    def sampler_take_probs(unsorted_probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32"), sorted_indices: R.Tensor(("batch_size", "vocab_size"), dtype="int32"), sample_indices: R.Tensor(("num_samples",), dtype="int32"), sampling_result: R.Tensor(("num_samples",), dtype="int32"), lobprob_offsets: R.Tensor(("num_positions",), dtype="int32")) -> R.Tuple(R.Tensor(("num_samples",), dtype="float32"), R.Tensor(("num_positions",), dtype="float32"), R.Tensor(("num_positions",), dtype="int32")):
        # Thin Relax wrapper around the single TIR kernel `sampler_take_probs_tir`.
        # Returns a 3-tuple of tensors, in order:
        #   - (num_samples,) float32
        #   - (num_positions,) float32
        #   - (num_positions,) int32
        # Presumably these are the probability of each sampled token, plus per-position
        # probabilities and token ids selected via `lobprob_offsets`; TODO confirm in the kernel.
        # NOTE(review): `lobprob_offsets` looks like a misspelling of "logprob_offsets".
        # It is left as-is because it is part of the public signature.
        num_samples = T.int64(is_size_var=True)
        num_positions = T.int64(is_size_var=True)
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # All five inputs go straight to the kernel, which allocates three fresh outputs.
            gv3 = R.call_tir(cls.sampler_take_probs_tir, (unsorted_probs, sorted_indices, sample_indices, sampling_result, lobprob_offsets), out_sinfo=[R.Tensor((num_samples,), dtype="float32"), R.Tensor((num_positions,), dtype="float32"), R.Tensor((num_positions,), dtype="int32")])
            R.output(gv3)
        return gv3

    @R.function
    def sampler_verify_draft_tokens(draft_probs: R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), draft_tokens: R.Tensor(("num_nodes",), dtype="int32"), model_probs: R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), token_tree_first_child: R.Tensor(("num_nodes",), dtype="int32"), token_tree_next_sibling: R.Tensor(("num_nodes",), dtype="int32"), uniform_samples: R.Tensor(("num_nodes",), dtype="float32"), token_tree_parent_ptr: R.Tensor(("nbatch",), dtype="int32")) -> R.Tuple(R.Tensor(("num_nodes", "vocab_size"), dtype="float32"), R.Tensor(("nbatch",), dtype="int32")):
        # Runs the `batch_verify_on_gpu_single_kernel` TIR kernel over a token tree
        # (first-child / next-sibling links plus a per-batch parent pointer).
        # The kernel is invoked in place: the returned tuple aliases two of the inputs, so
        # `model_probs` and `token_tree_parent_ptr` are mutated and returned.
        # NOTE(review): presumably this performs speculative-decoding acceptance of draft
        # tokens against model probabilities; TODO confirm in the kernel.
        num_nodes = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        nbatch = T.int64(is_size_var=True)
        # NOTE(review): the tir_var_upper_bound keys (batch_size, num_positions, num_samples)
        # match no symbolic var of this function (num_nodes, vocab_size, nbatch).
        # They look copied from sibling functions; confirm they are intended.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # inplace_indices=[2, 6]: output 0 reuses argument 2 (model_probs) and output 1
            # reuses argument 6 (token_tree_parent_ptr); no new output buffers are allocated.
            gv4: R.Tuple(R.Tensor((num_nodes, vocab_size), dtype="float32"), R.Tensor((nbatch,), dtype="int32")) = R.call_tir_inplace(cls.batch_verify_on_gpu_single_kernel, (draft_probs, draft_tokens, model_probs, token_tree_first_child, token_tree_next_sibling, uniform_samples, token_tree_parent_ptr), out_sinfo=[R.Tensor((num_nodes, vocab_size), dtype="float32"), R.Tensor((nbatch,), dtype="int32")], inplace_indices=[2, 6])
            R.output(gv4)
        return gv4

    @R.function
    def softmax_with_temperature(logits: R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"), temperature: R.Tensor(("batch_size",), dtype="float32")) -> R.Tensor(("batch_size", 1, "vocab_size"), dtype="float32"):
        # Temperature-scaled softmax over the vocab axis. It runs as two kernels over
        # 4096-wide vocab chunks:
        #   1. chunk_lse: per-chunk statistics, two (batch_size, ceil(vocab_size / 4096)) tensors.
        #   2. softmax_with_chunked_sum: combines those with the logits into the final probabilities.
        # Input and output both have shape (batch_size, 1, vocab_size).
        # NOTE(review): the exact contents of the two chunk tensors (presumably per-chunk max and
        # sum-exp, or similar log-sum-exp parts) are not visible here; confirm in cls.chunk_lse.
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 131072}})
        cls = Module
        with R.dataflow():
            # Drop the singleton middle axis: (batch_size, 1, vocab_size) -> (batch_size, vocab_size).
            lv: R.Tensor((batch_size, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", logits, R.shape([batch_size, vocab_size]), sinfo_args=(R.Tensor((batch_size, vocab_size), dtype="float32"),))
            # One output column per 4096-wide vocab chunk (ceil division via + 4096 - 1).
            lv1 = R.call_tir(cls.chunk_lse, (lv, temperature), out_sinfo=[R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32"), R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32")])
            lv2: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[0]
            lv3: R.Tensor((batch_size, (vocab_size + 4096 - 1) // 4096), dtype="float32") = lv1[1]
            # Final normalization using the original logits, temperature and both chunk tensors.
            lv4 = R.call_tir(cls.softmax_with_chunked_sum, (lv, temperature, lv2, lv3), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            # Restore the caller-facing (batch_size, 1, vocab_size) shape.
            gv: R.Tensor((batch_size, 1, vocab_size), dtype="float32") = R.call_pure_packed("vm.builtin.reshape", lv4, R.shape([batch_size, 1, vocab_size]), sinfo_args=(R.Tensor((batch_size, 1, vocab_size), dtype="float32"),))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.