# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    # Apply a per-token allow-bitmask to `logits` in place.
    #
    # For every selected sequence s in [0, num_seq) and token v in [0, vocab_size):
    #   r = seq_ids[s]
    #   keep logits[r, v] if bit (v % 32) of bitmask[r, v // 32] is 1, otherwise
    #   overwrite it with -3.4028235e38 (the most negative finite float32),
    #   presumably so the token can never be sampled.
    #
    # Buffers (shapes from the match_buffer calls below):
    #   logits  (batch_size, vocab_size), match_buffer default dtype (float32);
    #           read and written in place.
    #   seq_ids (num_seq,) int32: which rows of logits/bitmask to process.
    #   bitmask (batch_size, ceil(vocab_size / 32)) int32: one bit per token.
    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        # One thread per (s, v) element: the flat index s * vocab_size + v is spread
        # over 1024-thread blocks, and T.where masks off the tail of the last block.
        # NOTE(review): 1024 threads/block exceeds the target's max_num_threads (256)
        # but equals max_threads_per_block (1024) -- confirm on the deployed device.
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    # Recover (sequence, token) from the flat thread index.
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    # ((word >> bit) & 1) == 1 -> keep the logit, else replace it with -FLT_MAX.
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0))

    # Add a sparse logit bias to `logits` in place.
    #
    # For each entry p in [0, num_token):
    #   logits[pos2seq_id[p], token_ids[p]] += logit_bias[p]
    # pos2seq_id[p] is used directly as the row index into `logits`.
    #
    # Buffers (shapes from the match_buffer calls below):
    #   logits     (batch_size, vocab_size), match_buffer default dtype (float32).
    #   pos2seq_id (num_token,) int32: target row per entry.
    #   token_ids  (num_token,) int32: target column per entry.
    #   logit_bias (num_token,), default dtype: value added per entry.
    #
    # NOTE(review): each entry is an unsynchronized read-modify-write done by its
    # own thread, so this presumably assumes the (pos2seq_id[p], token_ids[p])
    # pairs are unique -- duplicates would race. TODO confirm with the caller.
    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        # One thread per entry, in 1024-thread blocks; T.where masks off the tail.
        # NOTE(review): 1024 threads/block exceeds the target's max_num_threads (256)
        # but equals max_threads_per_block (1024) -- confirm on the deployed device.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    # Apply per-sequence token penalties to `logits` in place.
    #
    # For each entry p in [0, num_token), let s = pos2seq_id[p] (an index into
    # seq_ids and penalties, not directly into logits), r = seq_ids[s] and
    # t = token_ids[p]. Then:
    #   x = logits[r, t] - (penalties[s, 0] + float(token_cnt[p]) * penalties[s, 1])
    #   logits[r, t] = x * penalties[s, 2] if x < 0 else x / penalties[s, 2]
    # Presumably penalties[s] = (presence, frequency, repetition) penalty -- TODO
    # confirm with the caller; the kernel only uses the columns as written above.
    # A zero penalties[s, 2] would divide by zero for non-negative x.
    #
    # NOTE(review): each entry is an unsynchronized read-modify-write done by its
    # own thread, so this presumably assumes (seq_ids[pos2seq_id[p]], token_ids[p])
    # pairs are unique -- duplicates would race. TODO confirm with the caller.
    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        # One thread per entry, in 1024-thread blocks; T.where masks off the tail.
        # NOTE(review): 1024 threads/block exceeds the target's max_num_threads (256)
        # but equals max_threads_per_block (1024) -- confirm on the deployed device.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    # Additive penalty: constant term plus a term proportional to token_cnt.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    # Sign-dependent scaling: negative logits are multiplied by column 2, non-negative ones divided by it.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    # Decode-step attention (one query token per batch entry) over a paged KV cache.
    #
    # Layout fixed by the match_buffer shapes and indexing below:
    #   Q / output  float16 (B, 16, 128): 16 query heads, head dim 128.
    #   pages       float16 (max_num_pages, 2, 2, 16, 128), indexed in the body as
    #               pages[page_no, 0 for K / 1 for V, kv_head, slot_in_page, dim],
    #               i.e. 2 KV heads and 16 slots per page.
    #   page_table_indptr (B + 1,) / page_table_values (nnz_pages,): CSR list of the
    #               page ids used by each batch entry.
    #   length_info (B,): added to (num_pages - 1) * 16 to form the KV length
    #               (presumably the fill count of the last page -- TODO confirm).
    #   lse         (B, 16): written as st_m + log2(st_d), a base-2 log-sum-exp.
    #
    # Query head by * 8 + ty attends to KV head `by`, so 8 query heads share a KV head.
    # When rotary_mode == 1, rotary embedding is applied to Q at position
    # q_rope_position[b] and to each K row at k_rope_pos_offset[b] + row.
    # The leading `_0` argument is not referenced in the body.
    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # sm_scale numerically equals log2(e) / sqrt(128). Scores are pre-multiplied by it,
        # so the exp2 calls below compute exp(q.k / sqrt(128)).
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Launch layout:
        #   blockIdx.x  bx: batch entry.
        #   blockIdx.y  fused_by_bz in [0, 2): by = KV head; bz = fused_by_bz // 2 is always 0.
        #   threadIdx.y ty: query head by * 8 + ty; also the tile row this thread loads.
        #   threadIdx.x tx: owns head-dim elements [tx * 4, tx * 4 + 4).
        #   threadIdx.z tz: scores rows [tz * 8, tz * 8 + 8) of each 16-row KV tile.
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                # Region annotations. The Q slice spans tx*4-64 .. tx*4+68 so the rotary partner
                                # element (offset +/-64) is covered; (fused_by_bz + 1) // 2 == fused_by_bz here.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                # Per-thread scratch (scope="local") and per-block tiles (scope="shared").
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # KV length: 16 per page except the last, plus length_info[b]; 0 if the entry has no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                # Online-softmax state (base 2): st_m = running max score, st_d = running sum of
                                # exp2(score - st_m), O_local = running unnormalized output. -50000.0 acts as -inf.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 Q elements. When rotary_mode == 1:
                                #   q'[d] = cos(f) * q[d] + sin(f) * (-q[d + 64] if d < 64 else q[d - 64])
                                #   f = q_rope_position[b] * rope_scale / rope_theta ** ((2 * d % 128) / 128)
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Walk the KV sequence in tiles of 16 rows (2 tz halves x 8 ty rows).
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    # Each (tz, ty) thread stages one K row and one V row in shared memory; rows at
                                    # or past kv_chunk_len are zero-filled. K gets RoPE at k_rope_pos_offset[b] + row_g.
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Logical row -> (page id, slot within the 16-slot page).
                                                seq_offset: T.int32 = row_g
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # Make the whole tile visible to every thread before scoring.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score the 8 rows of this tz half: partial dot product over this thread's 4 dims,
                                    # then summed across the 32 tx threads.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        # Sum-reduce S_reduce_local over tx; the result lands in t0[0].
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Rows past the end of the sequence keep the -inf sentinel score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous state to the new running max, then add this tile's weights.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    # O += sum_j weight_j * V[row j] over this thread's 4 output dims.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Merge the two tz partial states: publish (O, m, d) to shared memory, then every
                                # thread combines both halves for its query head.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    # Combine two (m, d, O) states by rescaling both to the larger max.
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize, store float16 output, and store the base-2 LSE.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    # Decode-step attention over a paged KV cache whose logical rows are remapped
    # through length_info. Otherwise identical in structure to a plain paged decode.
    #
    # Layout fixed by the match_buffer shapes and indexing below:
    #   Q / output  float16 (B, 16, 128): 16 query heads, head dim 128.
    #   pages       float16 (max_num_pages, 2, 2, 16, 128), indexed in the body as
    #               pages[page_no, 0 for K / 1 for V, kv_head, slot_in_page, dim].
    #   page_table_indptr (B + 1,) / page_table_values (nnz_pages,): CSR page list.
    #   length_info (3, B), used as follows:
    #     kv_chunk_len = (num_pages - 1) * 16 + length_info[0, b]
    #                    - length_info[1, b] + length_info[2, b]
    #     logical rows < length_info[2, b] map to cache position row; later rows map
    #     to row - length_info[2, b] + length_info[1, b].
    #     Presumably [0] = last-page fill, [1] = sliding-window start offset and
    #     [2] = attention-sink length -- TODO confirm with the cache manager.
    #   lse         (B, 16): written as st_m + log2(st_d), a base-2 log-sum-exp.
    #
    # Query head by * 8 + ty attends to KV head `by`. When rotary_mode == 1, RoPE is
    # applied to Q at q_rope_position[b] and to K at k_rope_pos_offset[b] + row_g,
    # where row_g is the logical (post-remap) row, not the physical cache position.
    # The leading `_0` argument is not referenced in the body.
    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # sm_scale numerically equals log2(e) / sqrt(128). Scores are pre-multiplied by it,
        # so the exp2 calls below compute exp(q.k / sqrt(128)).
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Launch layout:
        #   blockIdx.x  bx: batch entry.
        #   blockIdx.y  fused_by_bz in [0, 2): by = KV head; bz = fused_by_bz // 2 is always 0.
        #   threadIdx.y ty: query head by * 8 + ty; also the tile row this thread loads.
        #   threadIdx.x tx: owns head-dim elements [tx * 4, tx * 4 + 4).
        #   threadIdx.z tz: scores rows [tz * 8, tz * 8 + 8) of each 16-row KV tile.
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                # Region annotations. The Q slice spans tx*4-64 .. tx*4+68 so the rotary partner
                                # element (offset +/-64) is covered; (fused_by_bz + 1) // 2 == fused_by_bz here.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                # Per-thread scratch (scope="local") and per-block tiles (scope="shared").
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Visible KV length after the remap: cached rows minus the skipped span
                                # (length_info[1] - length_info[2]); 0 if the entry has no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                # Online-softmax state (base 2): st_m = running max score, st_d = running sum of
                                # exp2(score - st_m), O_local = running unnormalized output. -50000.0 acts as -inf.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 Q elements. When rotary_mode == 1:
                                #   q'[d] = cos(f) * q[d] + sin(f) * (-q[d + 64] if d < 64 else q[d - 64])
                                #   f = q_rope_position[b] * rope_scale / rope_theta ** ((2 * d % 128) / 128)
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Walk the visible KV rows in tiles of 16 (2 tz halves x 8 ty rows).
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    # Each (tz, ty) thread stages one K row and one V row in shared memory; rows at
                                    # or past kv_chunk_len are zero-filled. K gets RoPE at k_rope_pos_offset[b] + row_g.
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Remap the logical row: the first length_info[2] rows are read as-is,
                                                # later rows skip ahead by length_info[1] - length_info[2] positions.
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # Make the whole tile visible to every thread before scoring.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score the 8 rows of this tz half: partial dot product over this thread's 4 dims,
                                    # then summed across the 32 tx threads.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        # Sum-reduce S_reduce_local over tx; the result lands in t0[0].
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Rows past the visible length keep the -inf sentinel score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous state to the new running max, then add this tile's weights.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    # O += sum_j weight_j * V[row j] over this thread's 4 output dims.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Merge the two tz partial states: publish (O, m, d) to shared memory, then every
                                # thread combines both halves for its query head.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    # Combine two (m, d, O) states by rescaling both to the larger max.
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize, store float16 output, and store the base-2 LSE.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Batched prefill attention: ragged queries attend to a paged KV cache.
        # The kernel is tiled: S = Q K^T per 32x32 tile, an online softmax keeps
        # a running max m and running sum d per row, then O += P V per tile.
        #
        # Layout, as fixed by the buffer shapes and indexing below:
        #   q, output: (total_len, 16 query heads, 128), float16. Batch item b
        #              owns token rows q_indptr[b] .. q_indptr[b + 1] - 1.
        #   pages:     (max_num_pages, 2, 2, 16, 128), float16.
        #              Axis 1 selects K (0) or V (1).
        #              Axis 2 is the KV head, indexed by blockIdx.y (`by`).
        #              Axis 3 is the slot inside a 16-entry page.
        #   page_values[page_indptr[b] : page_indptr[b + 1]] lists the pages of b.
        #   lse:       (total_len, 16), float32 (default dtype). It receives
        #              m + log2(d), a base-2 log-sum-exp of the scaled scores.
        #
        # Each KV head serves 8 query heads (head index = by * 8 + h). The rows
        # processed here are (token, h) pairs flattened as token * 8 + h, and a
        # tile holds 32 such rows.
        #
        # Scalar arguments:
        #   causal > 0 enables the causal mask (see the "update1"/"update" blocks).
        #   rotary_mode == 1 applies rotary embedding to Q and K while loading them.
        #   rope_scale and rope_theta parameterize that rotary frequency.
        #   attn_score_scaling_factor is an extra multiplier on Q K^T.
        # NOTE(review): `_0` is not referenced anywhere in the body.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        # Launch shape: 16 x 2 blocks and 4 x 32 = 128 threads per block.
        #   blockIdx.x strides over query tiles.
        #   blockIdx.y selects the KV head.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Per-thread scalars and tiles live in scope "local".
                            # Per-block tiles and row statistics live in scope "shared".
                            # NOTE(review): `iterator` is allocated but never used; the KV loop
                            # below iterates with `iterator_1` instead.
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Work distribution: this block starts at global tile `bx`. After
                            # each tile it advances by 16, the blockIdx.x extent (see the end
                            # of the outer while loop). Tiles are counted per batch item as
                            # ceil(qo_len * 8 / 32).
                            # NOTE(review): q_indptr[1] is read before batch_size > 0 is checked;
                            # when batch_size == 0, q_indptr has only one element.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip forward over whole batch items until tile_id lands inside
                                # item batch_idx[0]. Refresh that item's row and tile counts.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    # First flattened (token * 8 + h) row of this tile.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    # KV length of this item = (pages - 1) * 16 + length_info[b].
                                    # length_info[b] is presumably the fill count of the last page
                                    # (TODO confirm against the caller). The length is 0 when the
                                    # item has no pages.
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    # Reset the running softmax state of each tile row:
                                    #   max m starts at -50000, which serves as "-inf";
                                    #   denominator d starts at 1.0.
                                    # row = ty * 32 + tx, so only the threads with ty == 0 own a row.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the 32 x 128 output accumulator.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32-row Q tile into shared memory.
                                    # - Tile row i maps to token q_indptr[b] + (LH_start + i) // 8 and to
                                    #   query head by * 8 + (LH_start + i) % 8.
                                    # - With rotary_mode == 1, rotary embedding is applied in rotate-half
                                    #   form: j < 64 pairs with -q[j + 64], otherwise with q[j - 64].
                                    # - Rows past the item's last token are zero-filled.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Walk the item's KV sequence in chunks of 32 positions.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        # Gather 32 K rows from the page table.
                                        # - KV position p lives in page page_values[begin + p // 16],
                                        #   slot p % 16.
                                        # - With rotary_mode == 1, rotary embedding is applied at
                                        #   position k_rope_pos_offset[b] + p.
                                        # - Positions at or beyond kv_chunk_len are zero-filled.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Gather the matching 32 V rows (pages axis 1 == 1). Rotary
                                        # embedding is not applied to V. Out-of-range rows are zero.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q K^T * attn_score_scaling_factor * 0.12751743082459868.
                                        # The constant equals log2(e) / sqrt(128), so S holds base-2 scores,
                                        # consistent with the exp2 / log2 used below. Each thread owns a
                                        # 2 x 4 sub-tile of S.
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish S to shared memory so each row owner can read its full row.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Online softmax step 1, done by the row owners (row < 32).
                                        # Take the new running max over the unmasked keys of this chunk and
                                        # rescale the old denominator: d_new = d * exp2(m_prev - m_new).
                                        # Mask, with row_ = query token index inside the item and
                                        # qo_len = q_indptr[b + 1] - q_indptr[b]:
                                        #   causal:     key p is kept iff p < kv_chunk_len - qo_len + row_ + 1
                                        #   non-causal: key p is kept iff p < kv_chunk_len
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Online softmax step 2. Overwrite each score with exp2(s - m_new).
                                        # Masked positions get exp2(-50000 - m_new), which is effectively 0.
                                        # NOTE(review): unlike the neighbouring blocks, the row < 32 guard
                                        # sits inside the j loop here; the effect is the same.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Online softmax step 3. Add this chunk's probabilities to d, then
                                        # publish the new m and d. Also publish the previous m so the O
                                        # accumulator can be rescaled next.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m_new) + P V for this KV chunk.
                                        # Each thread owns a 4 x 8 sub-tile of the 32 x 128 accumulator.
                                        # NOTE(review): O_gemm_init reads O_local, m_prev_smem and m_smem,
                                        # but it declares an empty T.reads().
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    # Normalize by the final denominator d and write float16 output. Only
                                    # rows that map to a real token (cur_L < q_indptr[b + 1]) are written.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Write the base-2 log-sum-exp m + log2(d) for every valid row. The
                                    # T.where predicate limits this to threads with ty == 0.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Move to this block's next tile; the stride 16 is the blockIdx.x extent.
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((128, 32), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 64:j - 64 + 129])
                                                            T.writes(K_smem[j, i])
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[L_kv_base + cur_L, by, j + 64] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[L_kv_base + cur_L, by, j])
                                                            else:
                                                                K_smem[j, i] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                V_smem[i, j] = v[L_kv_base + cur_L, by, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:128, 0:32])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k_1 = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k_1 = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1)
        tree_size = T.int32(is_size_var=True)
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("KV_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_base + L_kv_start + i
                                                            if L_kv_start + i < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[cur_L, by, j + 64] * T.float16(-1.0), k[cur_L, by, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[cur_L, by, j])
                                                                V_smem[i, j] = v[cur_L, by, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k_1 = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func(private=True)
    def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle):
        # Walk a draft token tree for every batch entry and decide, node by node,
        # which draft tokens are accepted against the model distribution
        # (naming suggests speculative-decoding verification -- confirm with caller).
        #
        # Launch shape: one thread block per batch entry b (blockIdx.x over nbatch),
        # 1024 threads per block (threadIdx.x).
        # Buffers as used below:
        #   draft_probs / model_probs: (num_nodes, vocab_size) float32, one row per tree node.
        #   draft_tokens[node]: token id proposed at that node.
        #   token_tree_first_child / token_tree_next_sibling: tree links; -1 means "none".
        #   uniform_samples[node]: random value used in that node's acceptance test.
        #   token_tree_parent_ptr[b]: starting node on entry, last accepted node on exit.
        # Side effects: token_tree_parent_ptr[b] is overwritten, and the model_probs row
        # of a parent whose child was rejected is overwritten with a normalized residual.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_nodes, vocab_size = T.int32(is_size_var=True), T.int64(is_size_var=True)
        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size))
        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32")
        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size))
        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32")
        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32")
        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,))
        nbatch = T.int32(is_size_var=True)
        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32")
        # with T.block("root"):
        # Per-thread scalars (local scope).
        child_ptr = T.alloc_buffer((1,), "int32", scope="local")
        parent_ptr = T.alloc_buffer((1,), "int32", scope="local")
        child_token = T.alloc_buffer((1,), "int32", scope="local")
        done = T.alloc_buffer((1,), "bool", scope="local")
        psum = T.alloc_buffer((1,), scope="local")
        t0 = T.alloc_buffer((1,), scope="local")
        model_prob_local = T.alloc_buffer((1,), scope="local")
        draft_prob_local = T.alloc_buffer((1,), scope="local")
        p_child = T.alloc_buffer((1,), scope="local")
        q_child = T.alloc_buffer((1,), scope="local")
        uniform_sample = T.alloc_buffer((1,), scope="local")
        # Acceptance decision made by thread 0 and shared with the whole block.
        pred_shared = T.alloc_buffer((1,), "bool", scope="shared")
        pred_local = T.alloc_buffer((1,), "bool", scope="local")
        for _bx in T.thread_binding(nbatch, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]])
                    T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b])
                    # Start at the node recorded for this batch entry; first candidate is its first child.
                    parent_ptr[0] = token_tree_parent_ptr[b]
                    child_ptr[0] = token_tree_first_child[parent_ptr[0]]
                    done[0] = T.bool(False)
                    # Each iteration tests one candidate child; the walk ends when there is none (-1).
                    while not done[0]:
                        T.tvm_storage_sync("shared")
                        if child_ptr[0] == -1:
                            done[0] = T.bool(True)
                            T.tvm_storage_sync("shared")
                        else:
                            # Thread 0 evaluates the acceptance test p_child >= u * q_child
                            # and publishes the result through shared memory.
                            if tx == 0:
                                child_token[0] = draft_tokens[child_ptr[0]]
                                p_child[0] = model_probs[parent_ptr[0], child_token[0]]
                                q_child[0] = draft_probs[child_ptr[0], child_token[0]]
                                uniform_sample[0] = uniform_samples[child_ptr[0]]
                                pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
                            T.tvm_storage_sync("shared")
                            # All threads read the same predicate, so the block branches uniformly.
                            pred_local[0] = pred_shared[0]
                            if pred_local[0]:
                                # Accepted: descend into the child's subtree.
                                parent_ptr[0] = child_ptr[0]
                                child_ptr[0] = token_tree_first_child[child_ptr[0]]
                            else:
                                # Rejected: accumulate the residual mass sum(max(model - draft, 0)).
                                # Each thread strides over the vocabulary in steps of 1024.
                                psum[0] = T.float32(0.0)
                                for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                    if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                        model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                        psum[0] = psum[0] + model_prob_local[0]
                                # Block-wide sum of the per-thread partial sums into t0.
                                with T.block("block_cross_thread"):
                                    T.reads(psum[0])
                                    T.writes(t0[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx)
                                if t0[0] < T.float32(9.9999999999999995e-08):
                                    # Residual mass is (near) zero: advance exactly as in the accepted branch.
                                    parent_ptr[0] = child_ptr[0]
                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]
                                else:
                                    # Overwrite the parent's model_probs row with the normalized
                                    # residual, then try the next sibling against it.
                                    for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                        if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                            model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                            model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0]
                                    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
                    # Thread 0 records the last accepted node for this batch entry.
                    if tx == 0:
                        token_tree_parent_ptr[b] = parent_ptr[0]

    @T.prim_func
    def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle):
        # Per-row, per-chunk log-sum-exp pieces of temperature-scaled values in A.
        # The column axis of A is split into chunks of 4096. For each (row, chunk)
        # this writes the chunk maximum to chunked_max and, to chunked_sum:
        #   - temperature[row] > 1e-5: log(sum(exp(A / temperature - max))) over the chunk;
        #   - otherwise: the count of in-range entries equal to the chunk max (A is not
        #     divided by temperature and no log is taken).
        # NOTE(review): assumes num_chunks * 4096 >= vocab_size (caller presumably uses
        # ceil division); columns past num_chunks * 4096 would be ignored -- confirm.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        # with T.block("root"):
        A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)))
        temp_max = T.alloc_buffer((batch_size, num_chunks))
        temp_sum = T.alloc_buffer((batch_size, num_chunks))
        # Reshape into (row, chunk, 4096), scaling by temperature when it is above 1e-5.
        # Out-of-range columns get the lowest finite float32 so they never win the max.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("pad"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2])
                T.writes(A_pad[v0, v1, v2])
                A_pad[v0, v1, v2] = T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0))
        # Chunk-wise maximum.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("max"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(A_pad[v0, v1, v2])
                T.writes(temp_max[v0, v1])
                with T.init():
                    temp_max[v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
                temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
        # Chunk-wise sum: exp(x - max) when temperature > 1e-5, else an indicator of x == max.
        # Out-of-range (padded) columns contribute 0 in both cases.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("sum_exp"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(temperature[v0], A_pad[v0, v1, v2], temp_max[v0, v1])
                T.writes(temp_sum[v0, v1])
                with T.init():
                    temp_sum[v0, v1] = T.float32(0.0)
                temp_sum[v0, v1] = temp_sum[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.Cast("float32", A_pad[v0, v1, v2] == temp_max[v0, v1])), T.float32(0.0))
        # Write outputs; the size-1 axis (v2) is unused. log() is applied only when temperature > 1e-5.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
            with T.block("log"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], temp_sum[v0, v1], temp_max[v0, v1])
                T.writes(chunked_sum[v0, v1], chunked_max[v0, v1])
                chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum[v0, v1]), temp_sum[v0, v1])
                chunked_max[v0, v1] = temp_max[v0, v1]

    @T.prim_func
    def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32):
        # Copy cached entries between token positions inside a paged buffer, per batch entry.
        # pages is declared (num_pages, 2, 2, 16, 128) float16. A flat position `pos` maps to
        # page pos // 16 and in-page slot pos % 16. Both entries of axis 1 are copied
        # (presumably K and V -- confirm); axis 2 is indexed by `h` and axis 4 by `d`.
        # For batch entry b, copies j in [copy_length_indptr[b], copy_length_indptr[b + 1])
        # move position copy_src_dst_pos[0, j] to copy_src_dst_pos[1, j].
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1)
        total_copy_length = T.int32()
        copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes()
            # One thread per (b, h, d): batch_size * 2 * 128 threads, rounded up to whole
            # 1024-thread blocks; the `if` below discards the surplus tail threads.
            for bhd_o in T.thread_binding((batch_size * 256 + 1023) // 1024, thread="blockIdx.x"):
                for bhd_i in T.thread_binding(1024, thread="threadIdx.x"):
                    b: T.int32 = (bhd_o * 1024 + bhd_i) // 256
                    h: T.int32 = (bhd_o * 1024 + bhd_i) // 128 % 2
                    d: T.int32 = (bhd_o * 1024 + bhd_i) % 128
                    if bhd_o * 1024 + bhd_i < batch_size * 2 * 128:
                        # Each thread applies its batch entry's copies sequentially, in index order.
                        for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d]

    @T.prim_func
    def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64):
        # Copy the first copy_length in-page slots of page src_page_id into page tgt_page_id.
        # pages is declared (num_pages, 2, 2, page_size, 128) float16. Both entries of
        # axis 1 (index 0 and 1) are copied for every (vh < 2, vp < copy_length, vd < 128).
        # NOTE(review): assumes copy_length <= page_size; this is not checked here.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        num_pages, page_size = T.int32(), T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        # with T.block("root"):
        # One thread per copied element: copy_length * 2 * 128 threads, rounded up to whole
        # 1024-thread blocks. The flat thread index f decomposes as
        # vh = f // (copy_length * 128), vp = f % (copy_length * 128) // 128, vd = f % 128.
        for b in T.thread_binding((copy_length * T.int64(256) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for t in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy"):
                    vh = T.axis.spatial(2, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(128))))
                    vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(128)) // T.int64(128))
                    vd = T.axis.spatial(128, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(128)))
                    # Drop tail threads past the last element.
                    T.where(b * T.int64(1024) + T.Cast("int64", t) < copy_length * T.int64(2) * T.int64(128))
                    T.reads(pages[src_page_id, 0:2, vh, vp, vd])
                    T.writes(pages[tgt_page_id, 0:2, vh, vp, vd])
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    @T.prim_func(private=True)
    def dequantize(model_embed_tokens_q_weight: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale: T.Buffer((T.int64(151936), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(151936), T.int64(2048)), "float16")):
        # 4-bit group-quantized -> float16 dequantization of a (151936, 2048) matrix.
        # Each uint32 word packs 8 consecutive 4-bit codes (code 0 in the low nibble);
        # each float16 scale covers 32 consecutive columns:
        #     out[r, c] = (float16(code[r, c]) - 7) * scale[r, c // 32]
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # Stage 1: unpack the raw 4-bit codes into a float16 scratch buffer.
        codes = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        for row, col in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(model_embed_tokens_q_weight[r, c // T.int64(8)])
                T.writes(codes[r, c])
                # Shift nibble (c % 8) down to bit 0, then keep only its 4 bits.
                codes[r, c] = T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight[r, c // T.int64(8)], T.Cast("uint32", c % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Stage 2: re-center the codes around 7 and apply the per-group scale.
        for row, col in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(codes[r, c], model_embed_tokens_q_scale[r, c // T.int64(32)])
                T.writes(dequantize[r, c])
                dequantize[r, c] = (codes[r, c] - T.float16(7.0)) * model_embed_tokens_q_scale[r, c // T.int64(32)]

    @T.prim_func(private=True)
    def dequantize1(model_layers_0_self_attn_c_attn_q_weight1: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale1: T.Buffer((T.int64(2560), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(2560), T.int64(2048)), "float16")):
        # 4-bit group-quantized -> float16 dequantization of a (2560, 2048) matrix.
        # Each uint32 word packs 8 consecutive 4-bit codes (code 0 in the low nibble);
        # each float16 scale covers 32 consecutive columns:
        #     out[r, c] = (float16(code[r, c]) - 7) * scale[r, c // 32]
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # Stage 1: unpack the raw 4-bit codes into a float16 scratch buffer.
        codes = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        for row, col in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("compute"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(model_layers_0_self_attn_c_attn_q_weight1[r, c // T.int64(8)])
                T.writes(codes[r, c])
                # Shift nibble (c % 8) down to bit 0, then keep only its 4 bits.
                codes[r, c] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight1[r, c // T.int64(8)], T.Cast("uint32", c % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Stage 2: re-center the codes around 7 and apply the per-group scale.
        for row, col in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("dequantize"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(codes[r, c], model_layers_0_self_attn_c_attn_q_scale1[r, c // T.int64(32)])
                T.writes(dequantize[r, c])
                dequantize[r, c] = (codes[r, c] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale1[r, c // T.int64(32)]

    @T.prim_func(private=True)
    def dequantize2(model_layers_0_self_attn_o_proj_q_weight1: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale1: T.Buffer((T.int64(2048), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(2048)), "float16")):
        # 4-bit group-quantized -> float16 dequantization of a (2048, 2048) matrix.
        # Each uint32 word packs 8 consecutive 4-bit codes (code 0 in the low nibble);
        # each float16 scale covers 32 consecutive columns:
        #     out[r, c] = (float16(code[r, c]) - 7) * scale[r, c // 32]
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # Stage 1: unpack the raw 4-bit codes into a float16 scratch buffer.
        codes = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        for row, col in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(model_layers_0_self_attn_o_proj_q_weight1[r, c // T.int64(8)])
                T.writes(codes[r, c])
                # Shift nibble (c % 8) down to bit 0, then keep only its 4 bits.
                codes[r, c] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight1[r, c // T.int64(8)], T.Cast("uint32", c % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Stage 2: re-center the codes around 7 and apply the per-group scale.
        for row, col in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(codes[r, c], model_layers_0_self_attn_o_proj_q_scale1[r, c // T.int64(32)])
                T.writes(dequantize[r, c])
                dequantize[r, c] = (codes[r, c] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale1[r, c // T.int64(32)]

    @T.prim_func(private=True)
    def dequantize3(model_layers_0_mlp_gate_up_proj_q_weight1: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale1: T.Buffer((T.int64(22016), T.int64(64)), "float16"), dequantize: T.Buffer((T.int64(22016), T.int64(2048)), "float16")):
        # 4-bit group-quantized -> float16 dequantization of a (22016, 2048) matrix.
        # Each uint32 word packs 8 consecutive 4-bit codes (code 0 in the low nibble);
        # each float16 scale covers 32 consecutive columns:
        #     out[r, c] = (float16(code[r, c]) - 7) * scale[r, c // 32]
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # Stage 1: unpack the raw 4-bit codes into a float16 scratch buffer.
        codes = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        for row, col in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("compute"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(model_layers_0_mlp_gate_up_proj_q_weight1[r, c // T.int64(8)])
                T.writes(codes[r, c])
                # Shift nibble (c % 8) down to bit 0, then keep only its 4 bits.
                codes[r, c] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight1[r, c // T.int64(8)], T.Cast("uint32", c % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Stage 2: re-center the codes around 7 and apply the per-group scale.
        for row, col in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("dequantize"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(codes[r, c], model_layers_0_mlp_gate_up_proj_q_scale1[r, c // T.int64(32)])
                T.writes(dequantize[r, c])
                dequantize[r, c] = (codes[r, c] - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale1[r, c // T.int64(32)]

    @T.prim_func(private=True)
    def dequantize4(model_layers_0_mlp_down_proj_q_weight1: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale1: T.Buffer((T.int64(2048), T.int64(344)), "float16"), dequantize: T.Buffer((T.int64(2048), T.int64(11008)), "float16")):
        # 4-bit group-quantized -> float16 dequantization of a (2048, 11008) matrix.
        # Each uint32 word packs 8 consecutive 4-bit codes (code 0 in the low nibble);
        # each float16 scale covers 32 consecutive columns (1376 = 11008 / 8, 344 = 11008 / 32):
        #     out[r, c] = (float16(code[r, c]) - 7) * scale[r, c // 32]
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        # Stage 1: unpack the raw 4-bit codes into a float16 scratch buffer.
        codes = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        for row, col in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("compute"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(model_layers_0_mlp_down_proj_q_weight1[r, c // T.int64(8)])
                T.writes(codes[r, c])
                # Shift nibble (c % 8) down to bit 0, then keep only its 4 bits.
                codes[r, c] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight1[r, c // T.int64(8)], T.Cast("uint32", c % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Stage 2: re-center the codes around 7 and apply the per-group scale.
        for row, col in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("dequantize"):
                r, c = T.axis.remap("SS", [row, col])
                T.reads(codes[r, c], model_layers_0_mlp_down_proj_q_scale1[r, c // T.int64(32)])
                T.writes(dequantize[r, c])
                dequantize[r, c] = (codes[r, c] - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale1[r, c // T.int64(32)]

    @T.prim_func
    def full(var_result: T.handle, value: T.int32):
        # Fill a (batch_size, 1) int32 buffer with the scalar `value`.
        #
        # var_result: handle bound to `result`, shape (batch_size, 1), int32;
        #             batch_size is taken from the runtime buffer shape.
        # value:      int32 scalar written to every row.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch_size = T.int32(is_size_var=True)
        result = T.match_buffer(var_result, (batch_size, 1), "int32")
        # with T.block("root"):
        for i in range(batch_size):
            with T.block("block"):
                vi = T.axis.spatial(batch_size, i)
                # No buffer is read; only the scalar parameter is used.
                T.reads()
                T.writes(result[vi, 0])
                result[vi, 0] = value

    @T.prim_func
    def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32):
        # Split a fused qkv tensor into q / k / v and optionally apply a
        # rotary position rotation to q and k.
        #
        # qkv:          (seq_len, 20, 128) float16; head slots 0..15 -> q,
        #               16..17 -> k, 18..19 -> v (see the branches below).
        # position_map: (seq_len,) int32 position of each token, used as the
        #               rotation angle numerator.
        # q:            (seq_len, 16, 128) float16 output.
        # k, v:         (seq_len, 2, 128) float16 outputs.
        # apply_rope:   when > 0, q and k are rotated; otherwise they are plain
        #               copies of qkv. v is always copied unchanged.
        T.func_attr({"op_pattern": 8, "target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        qkv = T.match_buffer(var_qkv, (seq_len, 20, 128), "float16")
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 16, 128), "float16")
        k = T.match_buffer(var_k, (seq_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (seq_len, 2, 128), "float16")
        # with T.block("root"):
        for iters_0, iters_1, iters_2 in T.grid(seq_len, 20, 128):
            with T.block("llama_fused_rope"):
                s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2])
                T.reads(position_map[s], qkv[s, h, d - 64:d - 64 + 129])
                T.writes(q[s, h, d], k[s, h - 16, d], v[s, h - 18, d])
                # Rotation used for q and k (computed in float32, stored as float16):
                #   freq = position / 1e6 ** ((2*d % 128) / 128)
                #   out  = cos(freq) * x[d] + sin(freq) * rot(x)[d]
                #   rot(x)[d] = -x[d + 64] if d < 64 else x[d - 64]
                # The `d < 128` guard is always true here because d ranges over
                # [0, 128), so every dim is rotated when apply_rope > 0.
                if h < 16:
                    # `freq` is a placeholder var bound through T.Let(..., where=...).
                    freq = T.float32()
                    q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                else:
                    if h < 18:
                        # Same rotation as q, written to k with the head index shifted by 16.
                        freq = T.float32()
                        k[s, h - 16, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                    else:
                        # Value heads are copied without rotation.
                        v[s, h - 18, d] = qkv[s, h, d]

    @T.prim_func
    def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row gather: dst[b, :] = src[indices[b], :].
        #
        # src:     (m, n) float32 source rows.
        # indices: (batch_size,) int32 row ids into src. There is no bounds check
        #          here; presumably the caller guarantees 0 <= indices[b] < m.
        # dst:     (batch_size, n) float32 output.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (m, n))
        batch_size = T.int32(is_size_var=True)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (batch_size, n))
        # with T.block("root"):
        for b, j in T.grid(batch_size, n):
            with T.block("gather_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                T.reads(src[indices[vb], vj], indices[vb])
                T.writes(dst[vb, vj])
                dst[vb, vj] = src[indices[vb], vj]

    @T.prim_func(private=True)
    def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
        # For each output row r, pick the token at the first sorted position j
        # of row sample_indices[r] where
        #     usample[r] < cumsum_sorted[row, j] / renorm_prob[row]
        # (or the last position if no earlier one qualifies), and write
        # indices[row, j] to output_index[r, 0].
        #
        # A -> cumsum_sorted: (batch, vocab_size) float32 running sums per row.
        # B -> indices:       (batch, vocab_size) int32 ids aligned with cumsum_sorted.
        # C -> renorm_prob:   (batch, 1) float32 normalizer per row.
        # D -> usample:       (out_batch, 1) float32 sample value per output row.
        # E -> sample_indices:(out_batch, 1) int32 row of A/B/C used by each output row.
        # F -> output_index:  (out_batch, 1) int32 result.
        #
        # Position j writes only if its own test passes and position j - 1's test
        # fails (or j == 0). NOTE(review): this yields exactly one writer per
        # output row only if each cumsum_sorted row is non-decreasing -- presumably
        # guaranteed by the producer; confirm.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        indices = T.match_buffer(B, (batch, vocab_size), "int32")
        renorm_prob = T.match_buffer(C, (batch, 1))
        out_batch = T.int64(is_size_var=True)
        usample = T.match_buffer(D, (out_batch, 1))
        sample_indices = T.match_buffer(E, (out_batch, 1), "int32")
        output_index = T.match_buffer(F, (out_batch, 1), "int32")
        # with T.block("root"):
        for ax0, ax1 in T.grid(out_batch, vocab_size):
            with T.block("T_get_index_from_sorted"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
                T.writes(output_index[v_ax0, 0])
                # Position v_ax1 passes its test, or it is the last position (fallback).
                if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
                    if v_ax1 == T.int64(0):
                        # There is no previous position to compare against.
                        output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                    else:
                        # Write only if the previous position did not already pass.
                        if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                            output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

    @T.prim_func(private=True)
    def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        # For each row b, write into renorm_prob[b, 0] the cumsum_sorted value at
        # the first position i where the "keep" predicate
        #     cumsum_sorted[b, i] < top_p[b] and i + 1 < top_k[b]
        # stops holding. If the predicate holds at every position, the value at
        # the last position is written instead.
        #
        # A -> cumsum_sorted: (batch, vocab_size) float32 running sums.
        # B -> top_p:         (batch, 1) float32 threshold.
        # C -> top_k:         (batch, 1) int32 limit.
        # D -> renorm_prob:   (batch, 1) float32 result.
        #
        # NOTE(review): exactly one position writes per row only if each
        # cumsum_sorted row is non-decreasing (so the predicate holds on a prefix).
        # Presumably the producer guarantees this; confirm.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        top_p = T.match_buffer(B, (batch, 1))
        top_k = T.match_buffer(C, (batch, 1), "int32")
        renorm_prob = T.match_buffer(D, (batch, 1))
        # with T.block("root"):
        for ax0, ax1 in T.grid(batch, vocab_size):
            with T.block("T_get_renorm_prob"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0])
                T.writes(renorm_prob[v_ax0, 0])
                # Predicate already fails at position 0: keep only the first entry.
                # Every v_ax1 iteration takes this branch and writes the same value.
                if not (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > 1):
                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
                else:
                    # Position v_ax1 still satisfies the predicate ...
                    if cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0]):
                        if v_ax1 + T.int64(1) == vocab_size:
                            # ... and it is the last position: use it.
                            renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
                        else:
                            # ... and the next position fails: use the next position.
                            if not (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0])):
                                renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]

    @T.prim_func(private=True)
    def index(var_rms_norm72: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Take the last entry along the sequence axis:
        #     index[0, 0, :] = rms_norm72[0, seq_len - 1, :]
        #
        # rms_norm72: (1, seq_len, 2048) float16 input.
        # index:      (1, 1, 2048) float16 output.
        # NOTE(review): assumes seq_len >= 1; seq_len is not declared is_size_var
        # and no guard is present here.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm72 = T.match_buffer(var_rms_norm72, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("index"):
                v_i, v__, v_k = T.axis.remap("SSS", [i, _, k])
                T.reads(rms_norm72[v_i, seq_len - T.int64(1), v_k])
                T.writes(index[v_i, v__, v_k])
                index[v_i, v__, v_k] = rms_norm72[v_i, seq_len - T.int64(1), v_k]

    @T.prim_func
    def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle):
        # Merge two partial states (V, S) and (V_other, S_other) in place into
        # (V, S), using base-2 exponentials:
        #     m      = max(S, S_other)
        #     a, b   = 2^(S - m), 2^(S_other - m)
        #     V     <- V * a / (a + b) + V_other * b / (a + b)
        #     S     <- log2(a + b) + m
        #
        # V, V_other: (N, H, D) float16. V is updated in place.
        # S, S_other: (N, H) float32. S is updated in place.
        #
        # Thread layout: blockIdx.x -> n; threadIdx.y (16) -> head ty + by * 16,
        # with by always 0; threadIdx.x (32) x 4-wide vectors -> dims 0..127.
        # NOTE(review): only heads [0, 16) and dims [0, 128) are touched, so this
        # presumably assumes H == 16 and D == 128. Other shapes would be
        # partially processed or indexed out of range. Confirm with the caller.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True)
        V = T.match_buffer(v, (N, H, D), "float16")
        S = T.match_buffer(s, (N, H))
        V_other = T.match_buffer(v_other, (N, H, D), "float16")
        S_other = T.match_buffer(s_other, (N, H))
        # with T.block("root"):
        for bx in T.thread_binding(N, thread="blockIdx.x"):
            for by in T.thread_binding(1, thread="blockIdx.y"):
                for ty in T.thread_binding(16, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("merge"):
                            T.reads(S[bx, ty + by * 16], S_other[bx, ty + by * 16], V[bx, ty + by * 16, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 16, tx * 4:tx * 4 + 4])
                            T.writes(V[bx, ty + by * 16, tx * 4:tx * 4 + 4], S[bx, ty + by * 16])
                            # Per-thread scratch: scalar scores/scales and a 4-wide V slice.
                            s_val = T.alloc_buffer((1,), scope="local")
                            s_other_val = T.alloc_buffer((1,), scope="local")
                            s_max = T.alloc_buffer((1,), scope="local")
                            scale = T.alloc_buffer((1,), scope="local")
                            other_scale = T.alloc_buffer((1,), scope="local")
                            v_vec = T.alloc_buffer((4,), "float16", scope="local")
                            v_other_vec = T.alloc_buffer((4,), "float16", scope="local")
                            s_val[0] = S[bx, ty + by * 16]
                            s_other_val[0] = S_other[bx, ty + by * 16]
                            # Subtract the max before exp2 to keep the exponentials <= 1.
                            s_max[0] = T.max(s_val[0], s_other_val[0])
                            s_val[0] = T.exp2(s_val[0] - s_max[0])
                            s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
                            scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
                            other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
                            for vec in T.vectorized(4):
                                v_vec[vec] = V[bx, ty + by * 16, tx * 4 + vec]
                            for vec in T.vectorized(4):
                                v_other_vec[vec] = V_other[bx, ty + by * 16, tx * 4 + vec]
                            # Weighted blend computed in float32, stored back as float16.
                            for vec in range(4):
                                v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0])
                            for vec in T.vectorized(4):
                                V[bx, ty + by * 16, tx * 4 + vec] = v_vec[vec]
                            # New combined score: log2 of the summed (shifted) weights plus the shift.
                            S[bx, ty + by * 16] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]

    @T.prim_func
    def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, var_top_prob_probs: T.handle, var_top_prob_indices: T.handle):
        # Two independent lookups fused into one loop of
        # num_positions + num_samples iterations.
        #
        # Iterations [0, num_positions):
        #   top_prob_offsets[i] is a flat offset into the (batch_size, vocab_size)
        #   sorted_indices matrix, decoded as row = off // vocab_size and
        #   col = off % vocab_size. The kernel writes
        #     top_prob_indices[i] = sorted_indices[row, col]
        #     top_prob_probs[i]   = unsorted_probs[row, sorted_indices[row, col]]
        # Iterations [num_positions, num_positions + num_samples), with j = i - num_positions:
        #   sampled_values[j] = unsorted_probs[sample_indices[j], sampling_results[j]]
        #
        # unsorted_probs: (batch_size, vocab_size) float32.
        # sorted_indices: (batch_size, vocab_size) int32.
        # sample_indices, sampling_results: (num_samples,) int32.
        # top_prob_offsets, top_prob_indices: (num_positions,) int32.
        # sampled_values: (num_samples,) float32; top_prob_probs: (num_positions,) float32.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1})})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size))
        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32")
        num_samples = T.int32(is_size_var=True)
        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32")
        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32")
        num_positions = T.int32(is_size_var=True)
        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32")
        sampled_values = T.match_buffer(var_sampled_values, (num_samples,))
        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,))
        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
        # with T.block("root"):
        for i in range(num_positions + num_samples):
            with T.block("block"):
                vi = T.axis.spatial(num_positions + num_samples, i)
                T.reads(top_prob_offsets[vi], sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], unsorted_probs[T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]):T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + (T.max(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + 1 - T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions])), T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]):T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + (T.max(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]))], sample_indices[vi - num_positions], sampling_results[vi - num_positions])
                T.writes(top_prob_indices[vi], top_prob_probs[vi], sampled_values[vi - num_positions])
                if vi < num_positions:
                    # Decode the flat offset into (row, col) of the sorted matrix.
                    row: T.int32 = top_prob_offsets[vi] // vocab_size
                    col: T.int32 = top_prob_offsets[vi] % vocab_size
                    top_prob_indices[vi] = sorted_indices[row, col]
                    top_prob_probs[vi] = unsorted_probs[row, sorted_indices[row, col]]
                else:
                    vj: T.int32 = vi - num_positions
                    sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]

    @T.prim_func
    def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row scatter: dst[indices[b], :] = src[b, :]. Rows of dst not named
        # in `indices` are left untouched.
        #
        # src:     (batch_size, n) float32 source rows.
        # indices: (batch_size,) int32 destination row ids. There is no bounds
        #          check; presumably 0 <= indices[b] < m and the ids are distinct
        #          (duplicate ids would give conflicting writes to the same row).
        # dst:     (m, n) float32, updated in place.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n))
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        m = T.int32(is_size_var=True)
        dst = T.match_buffer(var_dst, (m, n))
        # with T.block("root"):
        for b, j in T.grid(batch_size, n):
            with T.block("scatter_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                T.reads(src[vb, vj], indices[vb])
                T.writes(dst[indices[vb], vj])
                dst[indices[vb], vj] = src[vb, vj]

    @T.prim_func
    def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle):
        # Final softmax pass over logits split into 4096-wide chunks, using
        # per-chunk statistics (chunked_max, chunked_sum) produced elsewhere.
        #
        # A:           (batch_size, vocab_size) float32 logits.
        # temperature: (batch_size,) float32. A value <= 1e-5 selects the
        #              "argmax" path below.
        # chunked_sum, chunked_max: (batch_size, num_chunks) float32 per-chunk stats.
        # softmax:     (batch_size, vocab_size) float32 output.
        #
        # One thread block per (batch row, chunk). Each block first re-reduces
        # the row's chunk stats (32 threads), then writes its own 4096 outputs
        # (4 x 32 x 32 threads).
        # Regular path (temperature > 1e-5):
        #   sum = sum_c exp(chunked_sum[c] + chunked_max[c] - max)
        #   out = exp(A / temperature - (log(sum) + max))
        #   NOTE(review): this implies chunked_sum is log-domain relative to the
        #   chunk max, and chunked_max is over A / temperature. Confirm against
        #   the producing kernel.
        # Low-temperature path:
        #   sum = sum over chunks whose chunked_max == max of chunked_sum[c]
        #   out = (A == max) / sum, i.e. weight only on entries equal to the max.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size))
        # with T.block("root"):
        # Block-shared row max and row sum, indexed by batch row.
        temp_max_shared = T.alloc_buffer((batch_size,), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared")
        # l0_l1_fused encodes (row, chunk): row = fused // num_chunks, chunk = fused % num_chunks.
        for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Reduction 1: max over all chunks of this row (32 threads, strided by 32).
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("max"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(chunked_max[v0, v1])
                        T.writes(temp_max_shared[v0])
                        with T.init():
                            # Initialized to the most negative finite float32.
                            temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0)
                        temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1])
            # Reduction 2: combine per-chunk sums relative to the row max.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("sum_exp"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0])
                        T.writes(temp_sum_shared[v0])
                        with T.init():
                            temp_sum_shared[v0] = T.float32(0.0)
                        temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1])
            # Write this block's chunk: 4 serial steps x 32 x 32 threads = 4096 elements.
            for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("log_pad"):
                            v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                            v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks)
                            v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0])
                            T.writes(softmax[v0, v1 * T.int64(4096) + v2])
                            # The last chunk may be partial; skip columns past vocab_size.
                            if v1 * T.int64(4096) + v2 < vocab_size:
                                softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0])

    @T.prim_func(private=True)
    def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle):
        # Per-row gather along the vocab axis:
        #     take_sorted_probs[i, j] = probs[i, lv1[i, j]]
        #
        # probs: (batch_size, vocab_size) float32.
        # lv1:   (batch_size, vocab_size) int32 column ids into probs.
        # take_sorted_probs: (batch_size_1, vocab_size_1) float32 output.
        # NOTE(review): the loop runs over the output's own size vars, which are
        # declared separately from the inputs'. Presumably they match at runtime;
        # confirm with the caller.
        T.func_attr({"target": T.target({"keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size))
        lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32")
        batch_size_1, vocab_size_1 = T.int64(is_size_var=True), T.int64(is_size_var=True)
        take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1))
        # with T.block("root"):
        for i, j in T.grid(batch_size_1, vocab_size_1):
            with T.block("take_sorted_probs"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(probs[v_i, lv1[v_i, v_j]], lv1[v_i, v_j])
                T.writes(take_sorted_probs[v_i, v_j])
                take_sorted_probs[v_i, v_j] = probs[v_i, lv1[v_i, v_j]]

    @T.prim_func
    def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64):
        # Copy K and V entries out of a paged cache into dense per-layer buffers.
        #
        # pages:        (num_pages, 2, 2, page_size, 128) float16. Index 1 selects
        #               K (0) or V (1); index 2 is the head; index 3 is the slot
        #               within the page.
        # position_map: (seqlen,) int32 flat cache position of each entry,
        #               split into page = pos // page_size, slot = pos % page_size.
        # k_data, v_data: (36, seqlen, 2, 128) float16. Only the layer_id slice
        #               is written.
        # layer_id:     first-axis index into k_data / v_data.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        num_pages, page_size = T.int64(), T.int64(is_size_var=True)
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        seqlen = T.int64(is_size_var=True)
        position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1)
        k_data = T.match_buffer(var_k_data, (36, seqlen, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (36, seqlen, 2, 128), "float16")
        # with T.block("root"):
        for p, h, d in T.grid(seqlen, 2, 128):
            with T.block("copy0"):
                vp, vh, vd = T.axis.remap("SSS", [p, h, d])
                T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd])
                T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                position: T.int32 = position_map[vp]
                # Widen to int64 before the page/slot split to match page_size's dtype.
                k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd]
                v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd]

    @T.prim_func
    def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle):
        # Write new K/V rows into a paged cache with a fixed page size of 16.
        #
        # pages:        (num_pages, 2, 2, 16, 128) float16, updated in place. Index 1
        #               selects K (0) or V (1); index 2 is the head; index 3 is the
        #               slot in the page.
        # k_data, v_data: (ntoken, 2, 128) float16 rows to append.
        # position_map: (ntoken,) int32 target flat position per token, split into
        #               page = pos // 16, slot = pos % 16. Tokens mapped to -1 are
        #               skipped.
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.noalias": T.bool(True)})
        num_pages = T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        ntoken = T.int64(is_size_var=True)
        k_data = T.match_buffer(var_k_data, (ntoken, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (ntoken, 2, 128), "float16")
        position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1)
        # with T.block("root"):
        for global_pos, h, f in T.grid(ntoken, 2, 128):
            # -1 marks tokens that must not be written to the cache.
            if position_map[global_pos] != -1:
                with T.block("k_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf]
                with T.block("v_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf]

    @T.prim_func(private=True)
    def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle):
        # Search, independently for every batch row, for a probability cutoff
        # ("pivot") used by top-p filtering. The search keeps an interval [R, L]
        # and evaluates three candidate pivots per round.
        #
        # Buffers (float32 unless noted; shapes from the match_buffer calls):
        #   prob        (B, N)  row b is processed by thread block b (blockIdx.x),
        #                       using 1024 threads (threadIdx.x).
        #   top_p_arr   (B,)    per-row top-p threshold.
        #   init_pivots (B, 3)  candidate pivots for the first round.
        #   final_pivot (B,)    output: chosen cutoff (written by thread 0 only).
        #   final_lsum  (B,)    output: mass of probabilities >= the accepted
        #                       pivot (written by thread 0 only; see the fallback
        #                       branches at the end for the not-found case).
        #
        # For each candidate pivot, a round gathers over prob[b, :] >= pivot:
        #   lsum = their sum, lmin = their minimum, cmin = how many equal lmin.
        # NOTE(review): the early-exit test uses 1 - (mass seen so far), which
        # presumably assumes each prob row sums to 1 -- confirm with the caller.
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        top_p_arr = T.match_buffer(var_top_p_arr, (B,))
        init_pivots = T.match_buffer(var_init_pivots, (B, 3))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        # with T.block("root"):
        # Per-thread state lives in "local" buffers; values that thread 0
        # publishes to the whole block go through "shared" buffers.
        pivot = T.alloc_buffer((3,), scope="local")
        top_p = T.alloc_buffer((1,), scope="local")
        L = T.alloc_buffer((1,), scope="shared")
        R_1 = T.alloc_buffer((1,), scope="shared")
        L_local = T.alloc_buffer((1,), scope="local")
        R_local = T.alloc_buffer((1,), scope="local")
        q = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((3,), scope="local")
        lmin_broadcast = T.alloc_buffer((1,), scope="shared")
        lmin_broadcast_local = T.alloc_buffer((1,), scope="local")
        lmin = T.alloc_buffer((3,), scope="local")
        cmin = T.alloc_buffer((3,), "int32", scope="local")
        total_sum = T.alloc_buffer((1,), scope="local")
        it = T.alloc_buffer((1,), "int32", scope="local")
        es_local = T.alloc_buffer((1,), "bool", scope="local")
        es = T.alloc_buffer((1,), "bool", scope="shared")
        find_pivot_local = T.alloc_buffer((1,), "bool", scope="local")
        find_pivot = T.alloc_buffer((1,), "bool", scope="shared")
        total_sum_reduce = T.alloc_buffer((1,), scope="local")
        lsum_reduce = T.alloc_buffer((1,), scope="local")
        lmin_reduce = T.alloc_buffer((1,), scope="local")
        cmin_reduce = T.alloc_buffer((1,), "int32", scope="local")
        for _bx in T.thread_binding(B, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0])
                    T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0])
                    top_p[0] = top_p_arr[b]
                    # Thread 0 initialises the shared search state:
                    # upper bound L = 1 - top_p, lower bound R = 1e-8, not found.
                    if tx == 0:
                        L[0] = T.float32(1.0) - top_p[0]
                        R_1[0] = T.float32(9.9999999999999995e-08)
                        find_pivot[0] = T.bool(False)
                    T.tvm_storage_sync("shared")
                    # Every thread keeps private copies of the interval and of the
                    # three candidate pivots for this row.
                    L_local[0] = L[0]
                    R_local[0] = R_1[0]
                    for i in T.unroll(3):
                        pivot[i] = init_pivots[b, i]
                    find_pivot_local[0] = T.bool(False)
                    # Degenerate interval (L - R already <= 1e-8): emit pivot 0.0
                    # with mass 1.0 and skip the search loop entirely.
                    if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08):
                        if tx == 0:
                            final_lsum[b] = T.float32(1.0)
                            final_pivot[b] = T.float32(0.0)
                        find_pivot_local[0] = T.bool(True)
                    # Refine [R, L] until it is narrower than 1e-8 or a candidate
                    # pivot is accepted. The condition is block-uniform because
                    # L_local / R_local / find_pivot_local are all re-read from
                    # shared memory after a barrier at the end of each round.
                    while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]):
                        T.tvm_storage_sync("shared")
                        # Reset this thread's partial statistics for the round.
                        for pidx in T.unroll(3):
                            lsum[pidx] = T.float32(0.0)
                            lmin[pidx] = T.float32(340282346638528859811704183484516925440.0)
                            cmin[pidx] = 0
                        total_sum[0] = T.float32(0.0)
                        it[0] = 0
                        es_local[0] = T.bool(False)
                        # Scan row b in chunks of 1024 columns: thread tx reads
                        # column it * 1024 + tx; columns >= N contribute 0.0.
                        while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]:
                            q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0.0))
                            total_sum[0] = total_sum[0] + q[0]
                            for pidx in T.unroll(3):
                                # For values at or above the pivot, accumulate the
                                # mass and track the minimum and its multiplicity.
                                if q[0] >= pivot[pidx]:
                                    lsum[pidx] = lsum[pidx] + q[0]
                                    if lmin[pidx] > q[0]:
                                        lmin[pidx] = q[0]
                                        cmin[pidx] = 1
                                    else:
                                        if lmin[pidx] == q[0]:
                                            cmin[pidx] = cmin[pidx] + 1
                            it[0] = it[0] + 1
                            # Every 32 chunks, reduce the running total across the
                            # block. Thread 0 signals an early stop once the unseen
                            # mass (1 - total seen) is below pivot[2]; the flag is
                            # shared so every thread leaves the scan together.
                            # NOTE(review): this presumes pivot[2] is the smallest
                            # candidate. That holds for the pivots recomputed at
                            # the end of each round; the ordering of init_pivots
                            # is not visible here -- confirm.
                            if it[0] % 32 == 0:
                                with T.block("block_cross_thread"):
                                    T.reads(total_sum[0])
                                    T.writes(total_sum_reduce[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx)
                                if tx == 0:
                                    es[0] = T.float32(1.0) - total_sum_reduce[0] < pivot[2]
                                T.tvm_storage_sync("shared")
                                es_local[0] = es[0]
                        T.tvm_storage_sync("shared")
                        # Block-wide reduction of the per-candidate statistics.
                        # Afterwards only thread 0's local lsum/lmin/cmin hold the
                        # reduced values; other threads keep their partials.
                        for pidx in range(3):
                            with T.block("block_cross_thread"):
                                T.reads(lsum[pidx])
                                T.writes(lsum_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            # NOTE(review): the identity given to this min reducer
                            # is 0.0 rather than +inf. That is harmless only if the
                            # allreduce lowering never pads with the identity for a
                            # 1024-thread reduction -- confirm.
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            # Broadcast the block-wide minimum via shared memory so
                            # each thread can tell whether it holds that value.
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            # A larger local minimum means this thread saw no copy
                            # of the global minimum, so its count is dropped.
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            # Sum of the surviving counts = number of elements equal
                            # to the global minimum among values >= pivot[pidx].
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        # Thread 0 tests the candidates in index order, reusing `it`
                        # as the candidate index, and updates the shared state:
                        #  * lsum >= top_p, but dropping every copy of the smallest
                        #    kept value (lsum - cmin * lmin) falls below top_p:
                        #    accept this pivot and write both outputs.
                        #  * still >= top_p after that drop: the pivot is too low,
                        #    so R becomes this pivot (final_lsum recorded with it).
                        #  * lsum < top_p: the pivot is too high, so L becomes it.
                        # NOTE(review): R and L are overwritten in candidate order,
                        # not combined with max/min. With the descending pivots
                        # computed below, R therefore ends at the *smallest*
                        # too-low candidate. That is a valid but looser lower bound.
                        if tx == 0:
                            it[0] = 0
                            while it[0] < 3 and not find_pivot_local[0]:
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                 it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        # Publish the new interval and found flag to every thread.
                        # Place the next candidates at 1/4, 2/4 and 3/4 of the way
                        # from L down to R (descending). L[0] equals L_local[0]
                        # here, just after the barrier.
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4.0)
                    # No candidate was accepted: fall back to the lower bound R.
                    # If R still equals its initial 1e-8 (presumably it was never
                    # raised), no round wrote final_lsum. In that case thread 0's
                    # reduced mass for the last round's pivot[2] is used instead.
                    if tx == 0:
                        if not find_pivot_local[0]:
                            final_pivot[b] = R_local[0]
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    @T.prim_func(private=True)
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        # Apply a per-row cutoff and renormalise. For each row b and column j < N:
        #   renorm_prob[b, j] = prob[b, j] / final_lsum[b]   if prob[b, j] >= final_pivot[b]
        #                       0.0                          otherwise
        #
        # Launch layout: blockIdx.y selects the row. Each row is covered by
        # 511 // B + 1 blocks of 1024 threads, which is between 512 and
        # 511 + B blocks in total. Each thread visits the columns
        #   i * S + bx * 1024 + tx,   S = ceil(512 / B) * 1024 = (511 // B + 1) * 1024
        # for i in [0, ceil(N / S)). So S is exactly the number of threads
        # assigned to one row, and every column is written once.
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding(511 // B + 1, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        T.reads(final_pivot[by], final_lsum[by], prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)])
                        # Cache this row's cutoff and normalising mass.
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        # Loop extent is ceil(N / S) with S as described above.
                        for i in range((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024)):
                            # Guard the tail: the last pass may run past column N.
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0.0))

    @T.prim_func
    def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle):
        T.func_attr({"target": T.target({"host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["vulkan", "gpu"], "kind": "vulkan", "max_num_threads": 256, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "supports_16bit_buffer": True, "supports_8bit_buffer": True, "supports_float16": True, "supports_float32": True, "supports_int16": True, "supports_int32": True, "supports_int64": True, "supports_int8": True, "supports_storage_buffer_storage_class": True, "tag": "", "thread_warp_size": 1}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1)
        total_tree_order_len = T.int32(is_size_var=True)
        tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1)
        # with T.block("root"):
        assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention."
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @R.function
    def alloc_embedding_tensor() -> R.Tensor((2048, 2048), dtype="float16"):
        # Allocate and return one (2048, 2048) float16 tensor.
        # This function writes no values into it; whatever R.builtin.alloc_tensor
        # yields is returned as-is (presumably uninitialized storage - confirm).
        # NOTE(review): the name suggests the caller fills it with embeddings;
        # the 2048 width matches the (2048,) layernorm weights in this module,
        # but the row meaning is not established here - confirm against callers.
        # "relax.memory_plan_dynamic_func_output" opts this function's output into
        # Relax memory planning (presumably for dynamic-shape outputs - confirm).
        R.func_attr({"relax.memory_plan_dynamic_func_output": True})
        # alloc_tensor arguments: shape, dtype, R.prim_value(0) (presumably the
        # runtime device index), and storage scope "global".
        gv: R.Tensor((2048, 2048), dtype="float16") = R.builtin.alloc_tensor(R.shape([2048, 2048]), R.dtype("float16"), R.prim_value(0), R.str("global"))
        return gv

    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        # Sort each row of `probs` in descending order along the last axis.
        #
        # Returns a 2-tuple:
        #   [0] float32 (batch_size, vocab_size): output of the module's TIR kernel
        #       take_sorted_probs(probs, indices), presumably the probability
        #       values gathered in sorted order (kernel is defined elsewhere in
        #       Module - confirm).
        #   [1] int32 (batch_size, vocab_size): descending argsort indices.
        # Symbolic shape variables bound from the `probs` annotation.
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # "tir_var_upper_bound" caps batch_size at 80 (presumably consumed by memory
        # planning); the num_positions/num_samples entries are not referenced by
        # this function. "tir_non_negative_var" declares vocab_size non-negative.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Row-wise descending argsort over the last axis, with int32 indices.
            lv1: R.Tensor((batch_size, vocab_size), dtype="int32") = R.argsort(probs, axis=-1, descending=True, dtype="int32")
            # Combine the original probabilities with the sort indices via the TIR kernel.
            lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            # Pack (kernel output, indices) as the function result.
            gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1
            R.output(gv1)
        return gv1

    @R.function
    def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight4: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale4: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm219: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(input_embeds, model_layers_0_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv545 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight4, model_layers_0_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims435: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv545, axes=None)
            matmul435: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm219, permute_dims435, out_dtype="void")
            add324: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul435, model_layers_0_self_attn_c_attn_bias4)
            reshape432: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add324, R.shape([batch_size, 1, 20, 128]))
            reshape433: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape432, R.shape([batch_size, 20, 128]))
            lv546 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape433), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape434: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv546, R.shape([batch_size, 1, 16, 128]))
            reshape435: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape434, R.shape([batch_size, 1, 2048]))
            lv547 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight4, model_layers_0_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims436: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv547, axes=None)
            matmul436: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape435, permute_dims436, out_dtype="void")
            add325: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul436, input_embeds)
            rms_norm220: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add325, model_layers_0_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv548 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight4, model_layers_0_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims437: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv548, axes=None)
            matmul437: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm220, permute_dims437, out_dtype="void")
            split108: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul437, indices_or_sections=2, axis=-1)
            split_0108: R.Tensor((batch_size, 1, 11008), dtype="float16") = split108[0]
            split_1108: R.Tensor((batch_size, 1, 11008), dtype="float16") = split108[1]
            silu108: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0108)
            mul108: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu108, split_1108)
            lv549 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight4, model_layers_0_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims438: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv549, axes=None)
            matmul438: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul108, permute_dims438, out_dtype="void")
            add326: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul438, add325)
            rms_norm221: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add326, model_layers_1_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv550 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight4, model_layers_1_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims439: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv550, axes=None)
            matmul439: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm221, permute_dims439, out_dtype="void")
            add327: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul439, model_layers_1_self_attn_c_attn_bias4)
            reshape436: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add327, R.shape([batch_size, 1, 20, 128]))
            reshape437: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape436, R.shape([batch_size, 20, 128]))
            lv551 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape437), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape438: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv551, R.shape([batch_size, 1, 16, 128]))
            reshape439: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape438, R.shape([batch_size, 1, 2048]))
            lv552 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight4, model_layers_1_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims440: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv552, axes=None)
            matmul440: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape439, permute_dims440, out_dtype="void")
            add328: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul440, add326)
            rms_norm222: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add328, model_layers_1_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv553 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight4, model_layers_1_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims441: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv553, axes=None)
            matmul441: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm222, permute_dims441, out_dtype="void")
            split109: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul441, indices_or_sections=2, axis=-1)
            split_0109: R.Tensor((batch_size, 1, 11008), dtype="float16") = split109[0]
            split_1109: R.Tensor((batch_size, 1, 11008), dtype="float16") = split109[1]
            silu109: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0109)
            mul109: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu109, split_1109)
            lv554 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight4, model_layers_1_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims442: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv554, axes=None)
            matmul442: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul109, permute_dims442, out_dtype="void")
            add329: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul442, add328)
            rms_norm223: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add329, model_layers_2_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv555 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight4, model_layers_2_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims443: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv555, axes=None)
            matmul443: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm223, permute_dims443, out_dtype="void")
            add330: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul443, model_layers_2_self_attn_c_attn_bias4)
            reshape440: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add330, R.shape([batch_size, 1, 20, 128]))
            reshape441: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape440, R.shape([batch_size, 20, 128]))
            lv556 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape441), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape442: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv556, R.shape([batch_size, 1, 16, 128]))
            reshape443: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape442, R.shape([batch_size, 1, 2048]))
            lv557 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight4, model_layers_2_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims444: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv557, axes=None)
            matmul444: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape443, permute_dims444, out_dtype="void")
            add331: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul444, add329)
            rms_norm224: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add331, model_layers_2_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv558 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight4, model_layers_2_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims445: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv558, axes=None)
            matmul445: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm224, permute_dims445, out_dtype="void")
            split110: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul445, indices_or_sections=2, axis=-1)
            split_0110: R.Tensor((batch_size, 1, 11008), dtype="float16") = split110[0]
            split_1110: R.Tensor((batch_size, 1, 11008), dtype="float16") = split110[1]
            silu110: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0110)
            mul110: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu110, split_1110)
            lv559 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight4, model_layers_2_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims446: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv559, axes=None)
            matmul446: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul110, permute_dims446, out_dtype="void")
            add332: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul446, add331)
            rms_norm225: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add332, model_layers_3_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv560 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight4, model_layers_3_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims447: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv560, axes=None)
            matmul447: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm225, permute_dims447, out_dtype="void")
            add333: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul447, model_layers_3_self_attn_c_attn_bias4)
            reshape444: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add333, R.shape([batch_size, 1, 20, 128]))
            reshape445: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape444, R.shape([batch_size, 20, 128]))
            lv561 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape445), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape446: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv561, R.shape([batch_size, 1, 16, 128]))
            reshape447: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape446, R.shape([batch_size, 1, 2048]))
            lv562 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight4, model_layers_3_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims448: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv562, axes=None)
            matmul448: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape447, permute_dims448, out_dtype="void")
            add334: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul448, add332)
            rms_norm226: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add334, model_layers_3_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv563 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight4, model_layers_3_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims449: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv563, axes=None)
            matmul449: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm226, permute_dims449, out_dtype="void")
            split111: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul449, indices_or_sections=2, axis=-1)
            split_0111: R.Tensor((batch_size, 1, 11008), dtype="float16") = split111[0]
            split_1111: R.Tensor((batch_size, 1, 11008), dtype="float16") = split111[1]
            silu111: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0111)
            mul111: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu111, split_1111)
            lv564 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight4, model_layers_3_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims450: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv564, axes=None)
            matmul450: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul111, permute_dims450, out_dtype="void")
            add335: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul450, add334)
            rms_norm227: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add335, model_layers_4_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv565 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight4, model_layers_4_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims451: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv565, axes=None)
            matmul451: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm227, permute_dims451, out_dtype="void")
            add336: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul451, model_layers_4_self_attn_c_attn_bias4)
            reshape448: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add336, R.shape([batch_size, 1, 20, 128]))
            reshape449: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape448, R.shape([batch_size, 20, 128]))
            lv566 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape449), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape450: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv566, R.shape([batch_size, 1, 16, 128]))
            reshape451: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape450, R.shape([batch_size, 1, 2048]))
            lv567 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight4, model_layers_4_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims452: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv567, axes=None)
            matmul452: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape451, permute_dims452, out_dtype="void")
            add337: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul452, add335)
            rms_norm228: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add337, model_layers_4_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv568 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight4, model_layers_4_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims453: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv568, axes=None)
            matmul453: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm228, permute_dims453, out_dtype="void")
            split112: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul453, indices_or_sections=2, axis=-1)
            split_0112: R.Tensor((batch_size, 1, 11008), dtype="float16") = split112[0]
            split_1112: R.Tensor((batch_size, 1, 11008), dtype="float16") = split112[1]
            silu112: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0112)
            mul112: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu112, split_1112)
            lv569 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight4, model_layers_4_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims454: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv569, axes=None)
            matmul454: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul112, permute_dims454, out_dtype="void")
            add338: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul454, add337)
            rms_norm229: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add338, model_layers_5_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv570 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight4, model_layers_5_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims455: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv570, axes=None)
            matmul455: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm229, permute_dims455, out_dtype="void")
            add339: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul455, model_layers_5_self_attn_c_attn_bias4)
            reshape452: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add339, R.shape([batch_size, 1, 20, 128]))
            reshape453: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape452, R.shape([batch_size, 20, 128]))
            lv571 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape453), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape454: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv571, R.shape([batch_size, 1, 16, 128]))
            reshape455: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape454, R.shape([batch_size, 1, 2048]))
            lv572 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight4, model_layers_5_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims456: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv572, axes=None)
            matmul456: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape455, permute_dims456, out_dtype="void")
            add340: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul456, add338)
            rms_norm230: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add340, model_layers_5_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv573 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight4, model_layers_5_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims457: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv573, axes=None)
            matmul457: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm230, permute_dims457, out_dtype="void")
            split113: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul457, indices_or_sections=2, axis=-1)
            split_0113: R.Tensor((batch_size, 1, 11008), dtype="float16") = split113[0]
            split_1113: R.Tensor((batch_size, 1, 11008), dtype="float16") = split113[1]
            silu113: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0113)
            mul113: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu113, split_1113)
            lv574 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight4, model_layers_5_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims458: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv574, axes=None)
            matmul458: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul113, permute_dims458, out_dtype="void")
            add341: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul458, add340)
            rms_norm231: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add341, model_layers_6_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv575 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight4, model_layers_6_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims459: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv575, axes=None)
            matmul459: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm231, permute_dims459, out_dtype="void")
            add342: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul459, model_layers_6_self_attn_c_attn_bias4)
            reshape456: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add342, R.shape([batch_size, 1, 20, 128]))
            reshape457: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape456, R.shape([batch_size, 20, 128]))
            lv576 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape457), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape458: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv576, R.shape([batch_size, 1, 16, 128]))
            reshape459: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape458, R.shape([batch_size, 1, 2048]))
            lv577 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight4, model_layers_6_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims460: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv577, axes=None)
            matmul460: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape459, permute_dims460, out_dtype="void")
            add343: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul460, add341)
            rms_norm232: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add343, model_layers_6_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv578 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight4, model_layers_6_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims461: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv578, axes=None)
            matmul461: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm232, permute_dims461, out_dtype="void")
            split114: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul461, indices_or_sections=2, axis=-1)
            split_0114: R.Tensor((batch_size, 1, 11008), dtype="float16") = split114[0]
            split_1114: R.Tensor((batch_size, 1, 11008), dtype="float16") = split114[1]
            silu114: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0114)
            mul114: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu114, split_1114)
            lv579 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight4, model_layers_6_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims462: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv579, axes=None)
            matmul462: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul114, permute_dims462, out_dtype="void")
            add344: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul462, add343)
            rms_norm233: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add344, model_layers_7_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv580 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight4, model_layers_7_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims463: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv580, axes=None)
            matmul463: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm233, permute_dims463, out_dtype="void")
            add345: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul463, model_layers_7_self_attn_c_attn_bias4)
            reshape460: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add345, R.shape([batch_size, 1, 20, 128]))
            reshape461: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape460, R.shape([batch_size, 20, 128]))
            lv581 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape461), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape462: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv581, R.shape([batch_size, 1, 16, 128]))
            reshape463: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape462, R.shape([batch_size, 1, 2048]))
            lv582 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight4, model_layers_7_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims464: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv582, axes=None)
            matmul464: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape463, permute_dims464, out_dtype="void")
            add346: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul464, add344)
            rms_norm234: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add346, model_layers_7_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv583 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight4, model_layers_7_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims465: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv583, axes=None)
            matmul465: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm234, permute_dims465, out_dtype="void")
            split115: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul465, indices_or_sections=2, axis=-1)
            split_0115: R.Tensor((batch_size, 1, 11008), dtype="float16") = split115[0]
            split_1115: R.Tensor((batch_size, 1, 11008), dtype="float16") = split115[1]
            silu115: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0115)
            mul115: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu115, split_1115)
            lv584 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight4, model_layers_7_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims466: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv584, axes=None)
            matmul466: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul115, permute_dims466, out_dtype="void")
            add347: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul466, add346)
            rms_norm235: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add347, model_layers_8_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv585 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight4, model_layers_8_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims467: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv585, axes=None)
            matmul467: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm235, permute_dims467, out_dtype="void")
            add348: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul467, model_layers_8_self_attn_c_attn_bias4)
            reshape464: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add348, R.shape([batch_size, 1, 20, 128]))
            reshape465: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape464, R.shape([batch_size, 20, 128]))
            lv586 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape465), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape466: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv586, R.shape([batch_size, 1, 16, 128]))
            reshape467: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape466, R.shape([batch_size, 1, 2048]))
            lv587 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight4, model_layers_8_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims468: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv587, axes=None)
            matmul468: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape467, permute_dims468, out_dtype="void")
            add349: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul468, add347)
            rms_norm236: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add349, model_layers_8_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv588 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight4, model_layers_8_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims469: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv588, axes=None)
            matmul469: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm236, permute_dims469, out_dtype="void")
            split116: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul469, indices_or_sections=2, axis=-1)
            split_0116: R.Tensor((batch_size, 1, 11008), dtype="float16") = split116[0]
            split_1116: R.Tensor((batch_size, 1, 11008), dtype="float16") = split116[1]
            silu116: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0116)
            mul116: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu116, split_1116)
            lv589 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight4, model_layers_8_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims470: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv589, axes=None)
            matmul470: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul116, permute_dims470, out_dtype="void")
            add350: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul470, add349)
            rms_norm237: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add350, model_layers_9_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv590 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight4, model_layers_9_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims471: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv590, axes=None)
            matmul471: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm237, permute_dims471, out_dtype="void")
            add351: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul471, model_layers_9_self_attn_c_attn_bias4)
            reshape468: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add351, R.shape([batch_size, 1, 20, 128]))
            reshape469: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape468, R.shape([batch_size, 20, 128]))
            lv591 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape469), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape470: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv591, R.shape([batch_size, 1, 16, 128]))
            reshape471: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape470, R.shape([batch_size, 1, 2048]))
            lv592 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight4, model_layers_9_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims472: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv592, axes=None)
            matmul472: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape471, permute_dims472, out_dtype="void")
            add352: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul472, add350)
            rms_norm238: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add352, model_layers_9_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv593 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight4, model_layers_9_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims473: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv593, axes=None)
            matmul473: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm238, permute_dims473, out_dtype="void")
            split117: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul473, indices_or_sections=2, axis=-1)
            split_0117: R.Tensor((batch_size, 1, 11008), dtype="float16") = split117[0]
            split_1117: R.Tensor((batch_size, 1, 11008), dtype="float16") = split117[1]
            silu117: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0117)
            mul117: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu117, split_1117)
            lv594 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight4, model_layers_9_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims474: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv594, axes=None)
            matmul474: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul117, permute_dims474, out_dtype="void")
            add353: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul474, add352)
            rms_norm239: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add353, model_layers_10_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv595 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight4, model_layers_10_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims475: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv595, axes=None)
            matmul475: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm239, permute_dims475, out_dtype="void")
            add354: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul475, model_layers_10_self_attn_c_attn_bias4)
            reshape472: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add354, R.shape([batch_size, 1, 20, 128]))
            reshape473: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape472, R.shape([batch_size, 20, 128]))
            lv596 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape473), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape474: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv596, R.shape([batch_size, 1, 16, 128]))
            reshape475: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape474, R.shape([batch_size, 1, 2048]))
            lv597 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight4, model_layers_10_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims476: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv597, axes=None)
            matmul476: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape475, permute_dims476, out_dtype="void")
            add355: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul476, add353)
            rms_norm240: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add355, model_layers_10_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv598 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight4, model_layers_10_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims477: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv598, axes=None)
            matmul477: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm240, permute_dims477, out_dtype="void")
            split118: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul477, indices_or_sections=2, axis=-1)
            split_0118: R.Tensor((batch_size, 1, 11008), dtype="float16") = split118[0]
            split_1118: R.Tensor((batch_size, 1, 11008), dtype="float16") = split118[1]
            silu118: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0118)
            mul118: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu118, split_1118)
            lv599 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight4, model_layers_10_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims478: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv599, axes=None)
            matmul478: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul118, permute_dims478, out_dtype="void")
            add356: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul478, add355)
            rms_norm241: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add356, model_layers_11_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv600 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight4, model_layers_11_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims479: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv600, axes=None)
            matmul479: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm241, permute_dims479, out_dtype="void")
            add357: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul479, model_layers_11_self_attn_c_attn_bias4)
            reshape476: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add357, R.shape([batch_size, 1, 20, 128]))
            reshape477: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape476, R.shape([batch_size, 20, 128]))
            lv601 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape477), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape478: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv601, R.shape([batch_size, 1, 16, 128]))
            reshape479: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape478, R.shape([batch_size, 1, 2048]))
            lv602 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight4, model_layers_11_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims480: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv602, axes=None)
            matmul480: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape479, permute_dims480, out_dtype="void")
            add358: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul480, add356)
            rms_norm242: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add358, model_layers_11_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv603 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight4, model_layers_11_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims481: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv603, axes=None)
            matmul481: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm242, permute_dims481, out_dtype="void")
            split119: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul481, indices_or_sections=2, axis=-1)
            split_0119: R.Tensor((batch_size, 1, 11008), dtype="float16") = split119[0]
            split_1119: R.Tensor((batch_size, 1, 11008), dtype="float16") = split119[1]
            silu119: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0119)
            mul119: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu119, split_1119)
            lv604 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight4, model_layers_11_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims482: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv604, axes=None)
            matmul482: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul119, permute_dims482, out_dtype="void")
            add359: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul482, add358)
            rms_norm243: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add359, model_layers_12_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv605 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight4, model_layers_12_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims483: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv605, axes=None)
            matmul483: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm243, permute_dims483, out_dtype="void")
            add360: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul483, model_layers_12_self_attn_c_attn_bias4)
            reshape480: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add360, R.shape([batch_size, 1, 20, 128]))
            reshape481: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape480, R.shape([batch_size, 20, 128]))
            lv606 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape481), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape482: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv606, R.shape([batch_size, 1, 16, 128]))
            reshape483: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape482, R.shape([batch_size, 1, 2048]))
            lv607 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight4, model_layers_12_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims484: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv607, axes=None)
            matmul484: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape483, permute_dims484, out_dtype="void")
            add361: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul484, add359)
            rms_norm244: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add361, model_layers_12_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv608 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight4, model_layers_12_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims485: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv608, axes=None)
            matmul485: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm244, permute_dims485, out_dtype="void")
            split120: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul485, indices_or_sections=2, axis=-1)
            split_0120: R.Tensor((batch_size, 1, 11008), dtype="float16") = split120[0]
            split_1120: R.Tensor((batch_size, 1, 11008), dtype="float16") = split120[1]
            silu120: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0120)
            mul120: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu120, split_1120)
            lv609 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight4, model_layers_12_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims486: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv609, axes=None)
            matmul486: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul120, permute_dims486, out_dtype="void")
            add362: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul486, add361)
            rms_norm245: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add362, model_layers_13_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv610 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight4, model_layers_13_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims487: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv610, axes=None)
            matmul487: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm245, permute_dims487, out_dtype="void")
            add363: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul487, model_layers_13_self_attn_c_attn_bias4)
            reshape484: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add363, R.shape([batch_size, 1, 20, 128]))
            reshape485: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape484, R.shape([batch_size, 20, 128]))
            lv611 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape485), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape486: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv611, R.shape([batch_size, 1, 16, 128]))
            reshape487: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape486, R.shape([batch_size, 1, 2048]))
            lv612 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight4, model_layers_13_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims488: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv612, axes=None)
            matmul488: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape487, permute_dims488, out_dtype="void")
            add364: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul488, add362)
            rms_norm246: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add364, model_layers_13_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv613 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight4, model_layers_13_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims489: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv613, axes=None)
            matmul489: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm246, permute_dims489, out_dtype="void")
            split121: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul489, indices_or_sections=2, axis=-1)
            split_0121: R.Tensor((batch_size, 1, 11008), dtype="float16") = split121[0]
            split_1121: R.Tensor((batch_size, 1, 11008), dtype="float16") = split121[1]
            silu121: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0121)
            mul121: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu121, split_1121)
            lv614 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight4, model_layers_13_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims490: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv614, axes=None)
            matmul490: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul121, permute_dims490, out_dtype="void")
            add365: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul490, add364)
            rms_norm247: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add365, model_layers_14_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv615 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight4, model_layers_14_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims491: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv615, axes=None)
            matmul491: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm247, permute_dims491, out_dtype="void")
            add366: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul491, model_layers_14_self_attn_c_attn_bias4)
            reshape488: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add366, R.shape([batch_size, 1, 20, 128]))
            reshape489: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape488, R.shape([batch_size, 20, 128]))
            lv616 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape489), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape490: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv616, R.shape([batch_size, 1, 16, 128]))
            reshape491: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape490, R.shape([batch_size, 1, 2048]))
            lv617 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight4, model_layers_14_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims492: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv617, axes=None)
            matmul492: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape491, permute_dims492, out_dtype="void")
            add367: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul492, add365)
            rms_norm248: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add367, model_layers_14_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv618 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight4, model_layers_14_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims493: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv618, axes=None)
            matmul493: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm248, permute_dims493, out_dtype="void")
            split122: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul493, indices_or_sections=2, axis=-1)
            split_0122: R.Tensor((batch_size, 1, 11008), dtype="float16") = split122[0]
            split_1122: R.Tensor((batch_size, 1, 11008), dtype="float16") = split122[1]
            silu122: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0122)
            mul122: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu122, split_1122)
            lv619 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight4, model_layers_14_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims494: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv619, axes=None)
            matmul494: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul122, permute_dims494, out_dtype="void")
            add368: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul494, add367)
            rms_norm249: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add368, model_layers_15_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv620 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight4, model_layers_15_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims495: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv620, axes=None)
            matmul495: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm249, permute_dims495, out_dtype="void")
            add369: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul495, model_layers_15_self_attn_c_attn_bias4)
            reshape492: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add369, R.shape([batch_size, 1, 20, 128]))
            reshape493: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape492, R.shape([batch_size, 20, 128]))
            lv621 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape493), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape494: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv621, R.shape([batch_size, 1, 16, 128]))
            reshape495: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape494, R.shape([batch_size, 1, 2048]))
            lv622 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight4, model_layers_15_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims496: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv622, axes=None)
            matmul496: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape495, permute_dims496, out_dtype="void")
            add370: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul496, add368)
            rms_norm250: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add370, model_layers_15_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv623 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight4, model_layers_15_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims497: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv623, axes=None)
            matmul497: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm250, permute_dims497, out_dtype="void")
            split123: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul497, indices_or_sections=2, axis=-1)
            split_0123: R.Tensor((batch_size, 1, 11008), dtype="float16") = split123[0]
            split_1123: R.Tensor((batch_size, 1, 11008), dtype="float16") = split123[1]
            silu123: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0123)
            mul123: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu123, split_1123)
            lv624 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight4, model_layers_15_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims498: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv624, axes=None)
            matmul498: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul123, permute_dims498, out_dtype="void")
            add371: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul498, add370)
            rms_norm251: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add371, model_layers_16_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv625 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight4, model_layers_16_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims499: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv625, axes=None)
            matmul499: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm251, permute_dims499, out_dtype="void")
            add372: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul499, model_layers_16_self_attn_c_attn_bias4)
            reshape496: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add372, R.shape([batch_size, 1, 20, 128]))
            reshape497: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape496, R.shape([batch_size, 20, 128]))
            lv626 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape497), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape498: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv626, R.shape([batch_size, 1, 16, 128]))
            reshape499: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape498, R.shape([batch_size, 1, 2048]))
            lv627 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight4, model_layers_16_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims500: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv627, axes=None)
            matmul500: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape499, permute_dims500, out_dtype="void")
            add373: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul500, add371)
            rms_norm252: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add373, model_layers_16_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv628 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight4, model_layers_16_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims501: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv628, axes=None)
            matmul501: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm252, permute_dims501, out_dtype="void")
            split124: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul501, indices_or_sections=2, axis=-1)
            split_0124: R.Tensor((batch_size, 1, 11008), dtype="float16") = split124[0]
            split_1124: R.Tensor((batch_size, 1, 11008), dtype="float16") = split124[1]
            silu124: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0124)
            mul124: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu124, split_1124)
            lv629 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight4, model_layers_16_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims502: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv629, axes=None)
            matmul502: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul124, permute_dims502, out_dtype="void")
            add374: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul502, add373)
            rms_norm253: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add374, model_layers_17_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv630 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight4, model_layers_17_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims503: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv630, axes=None)
            matmul503: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm253, permute_dims503, out_dtype="void")
            add375: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul503, model_layers_17_self_attn_c_attn_bias4)
            reshape500: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add375, R.shape([batch_size, 1, 20, 128]))
            reshape501: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape500, R.shape([batch_size, 20, 128]))
            lv631 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape501), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape502: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv631, R.shape([batch_size, 1, 16, 128]))
            reshape503: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape502, R.shape([batch_size, 1, 2048]))
            lv632 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight4, model_layers_17_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims504: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv632, axes=None)
            matmul504: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape503, permute_dims504, out_dtype="void")
            add376: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul504, add374)
            rms_norm254: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add376, model_layers_17_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv633 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight4, model_layers_17_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims505: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv633, axes=None)
            matmul505: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm254, permute_dims505, out_dtype="void")
            split125: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul505, indices_or_sections=2, axis=-1)
            split_0125: R.Tensor((batch_size, 1, 11008), dtype="float16") = split125[0]
            split_1125: R.Tensor((batch_size, 1, 11008), dtype="float16") = split125[1]
            silu125: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0125)
            mul125: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu125, split_1125)
            lv634 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight4, model_layers_17_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims506: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv634, axes=None)
            matmul506: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul125, permute_dims506, out_dtype="void")
            add377: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul506, add376)
            rms_norm255: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add377, model_layers_18_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv635 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight4, model_layers_18_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims507: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv635, axes=None)
            matmul507: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm255, permute_dims507, out_dtype="void")
            add378: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul507, model_layers_18_self_attn_c_attn_bias4)
            reshape504: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add378, R.shape([batch_size, 1, 20, 128]))
            reshape505: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape504, R.shape([batch_size, 20, 128]))
            lv636 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape505), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape506: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv636, R.shape([batch_size, 1, 16, 128]))
            reshape507: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape506, R.shape([batch_size, 1, 2048]))
            lv637 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight4, model_layers_18_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims508: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv637, axes=None)
            matmul508: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape507, permute_dims508, out_dtype="void")
            add379: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul508, add377)
            rms_norm256: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add379, model_layers_18_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv638 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight4, model_layers_18_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims509: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv638, axes=None)
            matmul509: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm256, permute_dims509, out_dtype="void")
            split126: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul509, indices_or_sections=2, axis=-1)
            split_0126: R.Tensor((batch_size, 1, 11008), dtype="float16") = split126[0]
            split_1126: R.Tensor((batch_size, 1, 11008), dtype="float16") = split126[1]
            silu126: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0126)
            mul126: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu126, split_1126)
            lv639 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight4, model_layers_18_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims510: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv639, axes=None)
            matmul510: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul126, permute_dims510, out_dtype="void")
            add380: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul510, add379)
            rms_norm257: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add380, model_layers_19_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv640 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight4, model_layers_19_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims511: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv640, axes=None)
            matmul511: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm257, permute_dims511, out_dtype="void")
            add381: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul511, model_layers_19_self_attn_c_attn_bias4)
            reshape508: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add381, R.shape([batch_size, 1, 20, 128]))
            reshape509: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape508, R.shape([batch_size, 20, 128]))
            lv641 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape509), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape510: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv641, R.shape([batch_size, 1, 16, 128]))
            reshape511: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape510, R.shape([batch_size, 1, 2048]))
            lv642 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight4, model_layers_19_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims512: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv642, axes=None)
            matmul512: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape511, permute_dims512, out_dtype="void")
            add382: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul512, add380)
            rms_norm258: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add382, model_layers_19_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv643 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight4, model_layers_19_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims513: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv643, axes=None)
            matmul513: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm258, permute_dims513, out_dtype="void")
            split127: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul513, indices_or_sections=2, axis=-1)
            split_0127: R.Tensor((batch_size, 1, 11008), dtype="float16") = split127[0]
            split_1127: R.Tensor((batch_size, 1, 11008), dtype="float16") = split127[1]
            silu127: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0127)
            mul127: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu127, split_1127)
            lv644 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight4, model_layers_19_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims514: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv644, axes=None)
            matmul514: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul127, permute_dims514, out_dtype="void")
            add383: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul514, add382)
            rms_norm259: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add383, model_layers_20_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv645 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight4, model_layers_20_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims515: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv645, axes=None)
            matmul515: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm259, permute_dims515, out_dtype="void")
            add384: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul515, model_layers_20_self_attn_c_attn_bias4)
            reshape512: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add384, R.shape([batch_size, 1, 20, 128]))
            reshape513: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape512, R.shape([batch_size, 20, 128]))
            lv646 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape513), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape514: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv646, R.shape([batch_size, 1, 16, 128]))
            reshape515: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape514, R.shape([batch_size, 1, 2048]))
            lv647 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight4, model_layers_20_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims516: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv647, axes=None)
            matmul516: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape515, permute_dims516, out_dtype="void")
            add385: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul516, add383)
            rms_norm260: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add385, model_layers_20_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv648 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight4, model_layers_20_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims517: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv648, axes=None)
            matmul517: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm260, permute_dims517, out_dtype="void")
            split128: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul517, indices_or_sections=2, axis=-1)
            split_0128: R.Tensor((batch_size, 1, 11008), dtype="float16") = split128[0]
            split_1128: R.Tensor((batch_size, 1, 11008), dtype="float16") = split128[1]
            silu128: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0128)
            mul128: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu128, split_1128)
            lv649 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight4, model_layers_20_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims518: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv649, axes=None)
            matmul518: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul128, permute_dims518, out_dtype="void")
            add386: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul518, add385)
            rms_norm261: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add386, model_layers_21_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv650 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight4, model_layers_21_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims519: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv650, axes=None)
            matmul519: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm261, permute_dims519, out_dtype="void")
            add387: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul519, model_layers_21_self_attn_c_attn_bias4)
            reshape516: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add387, R.shape([batch_size, 1, 20, 128]))
            reshape517: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape516, R.shape([batch_size, 20, 128]))
            lv651 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape517), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape518: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv651, R.shape([batch_size, 1, 16, 128]))
            reshape519: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape518, R.shape([batch_size, 1, 2048]))
            lv652 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight4, model_layers_21_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims520: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv652, axes=None)
            matmul520: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape519, permute_dims520, out_dtype="void")
            add388: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul520, add386)
            rms_norm262: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add388, model_layers_21_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv653 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight4, model_layers_21_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims521: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv653, axes=None)
            matmul521: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm262, permute_dims521, out_dtype="void")
            split129: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul521, indices_or_sections=2, axis=-1)
            split_0129: R.Tensor((batch_size, 1, 11008), dtype="float16") = split129[0]
            split_1129: R.Tensor((batch_size, 1, 11008), dtype="float16") = split129[1]
            silu129: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0129)
            mul129: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu129, split_1129)
            lv654 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight4, model_layers_21_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims522: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv654, axes=None)
            matmul522: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul129, permute_dims522, out_dtype="void")
            add389: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul522, add388)
            rms_norm263: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add389, model_layers_22_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv655 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight4, model_layers_22_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims523: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv655, axes=None)
            matmul523: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm263, permute_dims523, out_dtype="void")
            add390: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul523, model_layers_22_self_attn_c_attn_bias4)
            reshape520: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add390, R.shape([batch_size, 1, 20, 128]))
            reshape521: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape520, R.shape([batch_size, 20, 128]))
            lv656 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape521), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape522: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv656, R.shape([batch_size, 1, 16, 128]))
            reshape523: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape522, R.shape([batch_size, 1, 2048]))
            lv657 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight4, model_layers_22_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims524: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv657, axes=None)
            matmul524: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape523, permute_dims524, out_dtype="void")
            add391: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul524, add389)
            rms_norm264: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add391, model_layers_22_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv658 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight4, model_layers_22_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims525: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv658, axes=None)
            matmul525: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm264, permute_dims525, out_dtype="void")
            split130: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul525, indices_or_sections=2, axis=-1)
            split_0130: R.Tensor((batch_size, 1, 11008), dtype="float16") = split130[0]
            split_1130: R.Tensor((batch_size, 1, 11008), dtype="float16") = split130[1]
            silu130: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0130)
            mul130: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu130, split_1130)
            lv659 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight4, model_layers_22_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims526: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv659, axes=None)
            matmul526: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul130, permute_dims526, out_dtype="void")
            add392: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul526, add391)
            rms_norm265: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add392, model_layers_23_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv660 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight4, model_layers_23_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims527: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv660, axes=None)
            matmul527: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm265, permute_dims527, out_dtype="void")
            add393: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul527, model_layers_23_self_attn_c_attn_bias4)
            reshape524: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add393, R.shape([batch_size, 1, 20, 128]))
            reshape525: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape524, R.shape([batch_size, 20, 128]))
            lv661 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape525), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape526: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv661, R.shape([batch_size, 1, 16, 128]))
            reshape527: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape526, R.shape([batch_size, 1, 2048]))
            lv662 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight4, model_layers_23_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims528: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv662, axes=None)
            matmul528: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape527, permute_dims528, out_dtype="void")
            add394: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul528, add392)
            rms_norm266: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add394, model_layers_23_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv663 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight4, model_layers_23_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims529: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv663, axes=None)
            matmul529: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm266, permute_dims529, out_dtype="void")
            split131: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul529, indices_or_sections=2, axis=-1)
            split_0131: R.Tensor((batch_size, 1, 11008), dtype="float16") = split131[0]
            split_1131: R.Tensor((batch_size, 1, 11008), dtype="float16") = split131[1]
            silu131: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0131)
            mul131: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu131, split_1131)
            lv664 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight4, model_layers_23_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims530: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv664, axes=None)
            matmul530: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul131, permute_dims530, out_dtype="void")
            add395: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul530, add394)
            rms_norm267: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add395, model_layers_24_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv665 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight4, model_layers_24_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims531: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv665, axes=None)
            matmul531: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm267, permute_dims531, out_dtype="void")
            add396: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul531, model_layers_24_self_attn_c_attn_bias4)
            reshape528: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add396, R.shape([batch_size, 1, 20, 128]))
            reshape529: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape528, R.shape([batch_size, 20, 128]))
            lv666 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape529), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape530: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv666, R.shape([batch_size, 1, 16, 128]))
            reshape531: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape530, R.shape([batch_size, 1, 2048]))
            lv667 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight4, model_layers_24_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims532: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv667, axes=None)
            matmul532: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape531, permute_dims532, out_dtype="void")
            add397: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul532, add395)
            rms_norm268: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add397, model_layers_24_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv668 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight4, model_layers_24_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims533: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv668, axes=None)
            matmul533: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm268, permute_dims533, out_dtype="void")
            split132: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul533, indices_or_sections=2, axis=-1)
            split_0132: R.Tensor((batch_size, 1, 11008), dtype="float16") = split132[0]
            split_1132: R.Tensor((batch_size, 1, 11008), dtype="float16") = split132[1]
            silu132: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0132)
            mul132: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu132, split_1132)
            lv669 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight4, model_layers_24_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims534: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv669, axes=None)
            matmul534: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul132, permute_dims534, out_dtype="void")
            add398: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul534, add397)
            rms_norm269: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add398, model_layers_25_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv670 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight4, model_layers_25_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims535: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv670, axes=None)
            matmul535: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm269, permute_dims535, out_dtype="void")
            add399: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul535, model_layers_25_self_attn_c_attn_bias4)
            reshape532: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add399, R.shape([batch_size, 1, 20, 128]))
            reshape533: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape532, R.shape([batch_size, 20, 128]))
            lv671 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape533), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape534: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv671, R.shape([batch_size, 1, 16, 128]))
            reshape535: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape534, R.shape([batch_size, 1, 2048]))
            lv672 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight4, model_layers_25_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims536: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv672, axes=None)
            matmul536: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape535, permute_dims536, out_dtype="void")
            add400: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul536, add398)
            rms_norm270: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add400, model_layers_25_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv673 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight4, model_layers_25_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims537: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv673, axes=None)
            matmul537: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm270, permute_dims537, out_dtype="void")
            split133: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul537, indices_or_sections=2, axis=-1)
            split_0133: R.Tensor((batch_size, 1, 11008), dtype="float16") = split133[0]
            split_1133: R.Tensor((batch_size, 1, 11008), dtype="float16") = split133[1]
            silu133: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0133)
            mul133: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu133, split_1133)
            lv674 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight4, model_layers_25_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims538: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv674, axes=None)
            matmul538: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul133, permute_dims538, out_dtype="void")
            add401: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul538, add400)
            rms_norm271: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add401, model_layers_26_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv675 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight4, model_layers_26_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims539: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv675, axes=None)
            matmul539: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm271, permute_dims539, out_dtype="void")
            add402: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul539, model_layers_26_self_attn_c_attn_bias4)
            reshape536: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add402, R.shape([batch_size, 1, 20, 128]))
            reshape537: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape536, R.shape([batch_size, 20, 128]))
            lv676 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape537), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape538: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv676, R.shape([batch_size, 1, 16, 128]))
            reshape539: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape538, R.shape([batch_size, 1, 2048]))
            lv677 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight4, model_layers_26_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims540: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv677, axes=None)
            matmul540: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape539, permute_dims540, out_dtype="void")
            add403: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul540, add401)
            rms_norm272: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add403, model_layers_26_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv678 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight4, model_layers_26_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims541: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv678, axes=None)
            matmul541: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm272, permute_dims541, out_dtype="void")
            split134: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul541, indices_or_sections=2, axis=-1)
            split_0134: R.Tensor((batch_size, 1, 11008), dtype="float16") = split134[0]
            split_1134: R.Tensor((batch_size, 1, 11008), dtype="float16") = split134[1]
            silu134: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0134)
            mul134: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu134, split_1134)
            lv679 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight4, model_layers_26_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims542: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv679, axes=None)
            matmul542: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul134, permute_dims542, out_dtype="void")
            add404: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul542, add403)
            rms_norm273: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add404, model_layers_27_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv680 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight4, model_layers_27_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims543: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv680, axes=None)
            matmul543: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm273, permute_dims543, out_dtype="void")
            add405: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul543, model_layers_27_self_attn_c_attn_bias4)
            reshape540: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add405, R.shape([batch_size, 1, 20, 128]))
            reshape541: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape540, R.shape([batch_size, 20, 128]))
            lv681 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape541), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape542: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv681, R.shape([batch_size, 1, 16, 128]))
            reshape543: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape542, R.shape([batch_size, 1, 2048]))
            lv682 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight4, model_layers_27_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims544: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv682, axes=None)
            matmul544: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape543, permute_dims544, out_dtype="void")
            add406: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul544, add404)
            rms_norm274: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add406, model_layers_27_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv683 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight4, model_layers_27_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims545: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv683, axes=None)
            matmul545: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm274, permute_dims545, out_dtype="void")
            split135: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul545, indices_or_sections=2, axis=-1)
            split_0135: R.Tensor((batch_size, 1, 11008), dtype="float16") = split135[0]
            split_1135: R.Tensor((batch_size, 1, 11008), dtype="float16") = split135[1]
            silu135: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0135)
            mul135: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu135, split_1135)
            lv684 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight4, model_layers_27_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims546: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv684, axes=None)
            matmul546: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul135, permute_dims546, out_dtype="void")
            add407: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul546, add406)
            rms_norm275: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add407, model_layers_28_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv685 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight4, model_layers_28_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims547: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv685, axes=None)
            matmul547: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm275, permute_dims547, out_dtype="void")
            add408: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul547, model_layers_28_self_attn_c_attn_bias4)
            reshape544: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add408, R.shape([batch_size, 1, 20, 128]))
            reshape545: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape544, R.shape([batch_size, 20, 128]))
            lv686 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape545), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape546: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv686, R.shape([batch_size, 1, 16, 128]))
            reshape547: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape546, R.shape([batch_size, 1, 2048]))
            lv687 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight4, model_layers_28_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims548: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv687, axes=None)
            matmul548: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape547, permute_dims548, out_dtype="void")
            add409: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul548, add407)
            rms_norm276: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add409, model_layers_28_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv688 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight4, model_layers_28_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims549: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv688, axes=None)
            matmul549: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm276, permute_dims549, out_dtype="void")
            split136: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul549, indices_or_sections=2, axis=-1)
            split_0136: R.Tensor((batch_size, 1, 11008), dtype="float16") = split136[0]
            split_1136: R.Tensor((batch_size, 1, 11008), dtype="float16") = split136[1]
            silu136: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0136)
            mul136: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu136, split_1136)
            lv689 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight4, model_layers_28_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims550: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv689, axes=None)
            matmul550: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul136, permute_dims550, out_dtype="void")
            add410: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul550, add409)
            rms_norm277: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add410, model_layers_29_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv690 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight4, model_layers_29_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims551: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv690, axes=None)
            matmul551: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm277, permute_dims551, out_dtype="void")
            add411: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul551, model_layers_29_self_attn_c_attn_bias4)
            reshape548: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add411, R.shape([batch_size, 1, 20, 128]))
            reshape549: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape548, R.shape([batch_size, 20, 128]))
            lv691 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape549), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape550: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv691, R.shape([batch_size, 1, 16, 128]))
            reshape551: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape550, R.shape([batch_size, 1, 2048]))
            lv692 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight4, model_layers_29_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims552: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv692, axes=None)
            matmul552: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape551, permute_dims552, out_dtype="void")
            add412: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul552, add410)
            rms_norm278: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add412, model_layers_29_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv693 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight4, model_layers_29_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims553: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv693, axes=None)
            matmul553: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm278, permute_dims553, out_dtype="void")
            split137: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul553, indices_or_sections=2, axis=-1)
            split_0137: R.Tensor((batch_size, 1, 11008), dtype="float16") = split137[0]
            split_1137: R.Tensor((batch_size, 1, 11008), dtype="float16") = split137[1]
            silu137: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0137)
            mul137: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu137, split_1137)
            lv694 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight4, model_layers_29_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims554: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv694, axes=None)
            matmul554: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul137, permute_dims554, out_dtype="void")
            add413: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul554, add412)
            rms_norm279: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add413, model_layers_30_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv695 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight4, model_layers_30_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims555: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv695, axes=None)
            matmul555: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm279, permute_dims555, out_dtype="void")
            add414: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul555, model_layers_30_self_attn_c_attn_bias4)
            reshape552: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add414, R.shape([batch_size, 1, 20, 128]))
            reshape553: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape552, R.shape([batch_size, 20, 128]))
            lv696 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape553), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape554: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv696, R.shape([batch_size, 1, 16, 128]))
            reshape555: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape554, R.shape([batch_size, 1, 2048]))
            lv697 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight4, model_layers_30_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims556: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv697, axes=None)
            matmul556: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape555, permute_dims556, out_dtype="void")
            add415: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul556, add413)
            rms_norm280: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add415, model_layers_30_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv698 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight4, model_layers_30_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims557: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv698, axes=None)
            matmul557: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm280, permute_dims557, out_dtype="void")
            split138: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul557, indices_or_sections=2, axis=-1)
            split_0138: R.Tensor((batch_size, 1, 11008), dtype="float16") = split138[0]
            split_1138: R.Tensor((batch_size, 1, 11008), dtype="float16") = split138[1]
            silu138: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0138)
            mul138: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu138, split_1138)
            lv699 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight4, model_layers_30_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims558: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv699, axes=None)
            matmul558: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul138, permute_dims558, out_dtype="void")
            add416: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul558, add415)
            rms_norm281: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add416, model_layers_31_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv700 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight4, model_layers_31_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims559: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv700, axes=None)
            matmul559: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm281, permute_dims559, out_dtype="void")
            add417: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul559, model_layers_31_self_attn_c_attn_bias4)
            reshape556: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add417, R.shape([batch_size, 1, 20, 128]))
            reshape557: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape556, R.shape([batch_size, 20, 128]))
            lv701 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape557), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape558: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv701, R.shape([batch_size, 1, 16, 128]))
            reshape559: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape558, R.shape([batch_size, 1, 2048]))
            lv702 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight4, model_layers_31_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims560: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv702, axes=None)
            matmul560: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape559, permute_dims560, out_dtype="void")
            add418: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul560, add416)
            rms_norm282: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add418, model_layers_31_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv703 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight4, model_layers_31_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims561: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv703, axes=None)
            matmul561: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm282, permute_dims561, out_dtype="void")
            split139: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul561, indices_or_sections=2, axis=-1)
            split_0139: R.Tensor((batch_size, 1, 11008), dtype="float16") = split139[0]
            split_1139: R.Tensor((batch_size, 1, 11008), dtype="float16") = split139[1]
            silu139: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0139)
            mul139: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu139, split_1139)
            lv704 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight4, model_layers_31_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims562: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv704, axes=None)
            matmul562: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul139, permute_dims562, out_dtype="void")
            add419: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul562, add418)
            rms_norm283: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add419, model_layers_32_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv705 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight4, model_layers_32_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims563: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv705, axes=None)
            matmul563: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm283, permute_dims563, out_dtype="void")
            add420: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul563, model_layers_32_self_attn_c_attn_bias4)
            reshape560: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add420, R.shape([batch_size, 1, 20, 128]))
            reshape561: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape560, R.shape([batch_size, 20, 128]))
            lv706 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape561), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape562: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv706, R.shape([batch_size, 1, 16, 128]))
            reshape563: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape562, R.shape([batch_size, 1, 2048]))
            lv707 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight4, model_layers_32_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims564: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv707, axes=None)
            matmul564: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape563, permute_dims564, out_dtype="void")
            add421: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul564, add419)
            rms_norm284: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add421, model_layers_32_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv708 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight4, model_layers_32_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims565: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv708, axes=None)
            matmul565: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm284, permute_dims565, out_dtype="void")
            split140: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul565, indices_or_sections=2, axis=-1)
            split_0140: R.Tensor((batch_size, 1, 11008), dtype="float16") = split140[0]
            split_1140: R.Tensor((batch_size, 1, 11008), dtype="float16") = split140[1]
            silu140: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0140)
            mul140: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu140, split_1140)
            lv709 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight4, model_layers_32_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims566: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv709, axes=None)
            matmul566: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul140, permute_dims566, out_dtype="void")
            add422: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul566, add421)
            rms_norm285: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add422, model_layers_33_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv710 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight4, model_layers_33_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims567: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv710, axes=None)
            matmul567: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm285, permute_dims567, out_dtype="void")
            add423: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul567, model_layers_33_self_attn_c_attn_bias4)
            reshape564: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add423, R.shape([batch_size, 1, 20, 128]))
            reshape565: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape564, R.shape([batch_size, 20, 128]))
            lv711 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape565), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape566: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv711, R.shape([batch_size, 1, 16, 128]))
            reshape567: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape566, R.shape([batch_size, 1, 2048]))
            lv712 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight4, model_layers_33_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims568: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv712, axes=None)
            matmul568: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape567, permute_dims568, out_dtype="void")
            add424: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul568, add422)
            rms_norm286: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add424, model_layers_33_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv713 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight4, model_layers_33_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims569: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv713, axes=None)
            matmul569: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm286, permute_dims569, out_dtype="void")
            split141: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul569, indices_or_sections=2, axis=-1)
            split_0141: R.Tensor((batch_size, 1, 11008), dtype="float16") = split141[0]
            split_1141: R.Tensor((batch_size, 1, 11008), dtype="float16") = split141[1]
            silu141: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0141)
            mul141: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu141, split_1141)
            lv714 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight4, model_layers_33_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims570: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv714, axes=None)
            matmul570: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul141, permute_dims570, out_dtype="void")
            add425: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul570, add424)
            rms_norm287: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add425, model_layers_34_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv715 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight4, model_layers_34_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims571: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv715, axes=None)
            matmul571: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm287, permute_dims571, out_dtype="void")
            add426: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul571, model_layers_34_self_attn_c_attn_bias4)
            reshape568: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add426, R.shape([batch_size, 1, 20, 128]))
            reshape569: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape568, R.shape([batch_size, 20, 128]))
            lv716 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape569), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape570: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv716, R.shape([batch_size, 1, 16, 128]))
            reshape571: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape570, R.shape([batch_size, 1, 2048]))
            lv717 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight4, model_layers_34_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims572: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv717, axes=None)
            matmul572: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape571, permute_dims572, out_dtype="void")
            add427: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul572, add425)
            rms_norm288: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add427, model_layers_34_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv718 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight4, model_layers_34_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims573: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv718, axes=None)
            matmul573: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm288, permute_dims573, out_dtype="void")
            split142: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul573, indices_or_sections=2, axis=-1)
            split_0142: R.Tensor((batch_size, 1, 11008), dtype="float16") = split142[0]
            split_1142: R.Tensor((batch_size, 1, 11008), dtype="float16") = split142[1]
            silu142: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0142)
            mul142: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu142, split_1142)
            lv719 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight4, model_layers_34_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims574: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv719, axes=None)
            matmul574: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul142, permute_dims574, out_dtype="void")
            add428: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul574, add427)
            rms_norm289: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add428, model_layers_35_input_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv720 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight4, model_layers_35_self_attn_c_attn_q_scale4), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims575: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv720, axes=None)
            matmul575: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.matmul(rms_norm289, permute_dims575, out_dtype="void")
            add429: R.Tensor((batch_size, 1, 2560), dtype="float16") = R.add(matmul575, model_layers_35_self_attn_c_attn_bias4)
            reshape572: R.Tensor((batch_size, 1, 20, 128), dtype="float16") = R.reshape(add429, R.shape([batch_size, 1, 20, 128]))
            reshape573: R.Tensor((batch_size, 20, 128), dtype="float16") = R.reshape(reshape572, R.shape([batch_size, 20, 128]))
            lv721 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape573), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape574: R.Tensor((batch_size, 1, 16, 128), dtype="float16") = R.reshape(lv721, R.shape([batch_size, 1, 16, 128]))
            reshape575: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.reshape(reshape574, R.shape([batch_size, 1, 2048]))
            lv722 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight4, model_layers_35_self_attn_o_proj_q_scale4), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims576: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv722, axes=None)
            matmul576: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(reshape575, permute_dims576, out_dtype="void")
            add430: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul576, add428)
            rms_norm290: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add430, model_layers_35_post_attention_layernorm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv723 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight4, model_layers_35_mlp_gate_up_proj_q_scale4), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims577: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv723, axes=None)
            matmul577: R.Tensor((batch_size, 1, 22016), dtype="float16") = R.matmul(rms_norm290, permute_dims577, out_dtype="void")
            split143: R.Tuple(R.Tensor((batch_size, 1, 11008), dtype="float16"), R.Tensor((batch_size, 1, 11008), dtype="float16")) = R.split(matmul577, indices_or_sections=2, axis=-1)
            split_0143: R.Tensor((batch_size, 1, 11008), dtype="float16") = split143[0]
            split_1143: R.Tensor((batch_size, 1, 11008), dtype="float16") = split143[1]
            silu143: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.nn.silu(split_0143)
            mul143: R.Tensor((batch_size, 1, 11008), dtype="float16") = R.multiply(silu143, split_1143)
            lv724 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight4, model_layers_35_mlp_down_proj_q_scale4), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims578: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv724, axes=None)
            matmul578: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.matmul(mul143, permute_dims578, out_dtype="void")
            add431: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.add(matmul578, add430)
            rms_norm291: R.Tensor((batch_size, 1, 2048), dtype="float16") = R.nn.rms_norm(add431, model_norm_weight4, axes=[-1], epsilon=9.9999999999999995e-07)
            lv725 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight4, model_embed_tokens_q_scale4), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            permute_dims579: R.Tensor((2048, 151936), dtype="float16") = R.permute_dims(lv725, axes=None)
            matmul579: R.Tensor((batch_size, 1, 151936), dtype="float32") = R.matmul(rms_norm291, permute_dims579, out_dtype="float32")
            gv4: R.Tuple(R.Tensor((batch_size, 1, 151936), dtype="float32"), R.Object) = matmul579, paged_kv_cache
            R.output(gv4)
        return gv4

    @R.function
    def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), 
dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), 
R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), 
R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), 
dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), 
R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        seq_len = T.int64()
        R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight3: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale3: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm146: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(input_embeds, model_layers_0_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv364 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight3, model_layers_0_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims290: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv364, axes=None)
            matmul290: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm146, permute_dims290, out_dtype="void")
            add216: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul290, model_layers_0_self_attn_c_attn_bias3)
            reshape288: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add216, R.shape([1, seq_len, 20, 128]))
            reshape289: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape288, R.shape([seq_len, 20, 128]))
            lv365 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape289), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape290: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv365, R.shape([1, seq_len, 16, 128]))
            reshape291: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape290, R.shape([1, seq_len, 2048]))
            lv366 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight3, model_layers_0_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims291: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv366, axes=None)
            matmul291: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape291, permute_dims291, out_dtype="void")
            add217: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul291, input_embeds)
            rms_norm147: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add217, model_layers_0_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv367 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight3, model_layers_0_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims292: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv367, axes=None)
            matmul292: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm147, permute_dims292, out_dtype="void")
            split72: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul292, indices_or_sections=2, axis=-1)
            split_072: R.Tensor((1, seq_len, 11008), dtype="float16") = split72[0]
            split_172: R.Tensor((1, seq_len, 11008), dtype="float16") = split72[1]
            silu72: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_072)
            mul72: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu72, split_172)
            lv368 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight3, model_layers_0_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims293: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv368, axes=None)
            matmul293: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul72, permute_dims293, out_dtype="void")
            add218: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul293, add217)
            rms_norm148: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add218, model_layers_1_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv369 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight3, model_layers_1_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims294: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv369, axes=None)
            matmul294: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm148, permute_dims294, out_dtype="void")
            add219: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul294, model_layers_1_self_attn_c_attn_bias3)
            reshape292: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add219, R.shape([1, seq_len, 20, 128]))
            reshape293: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape292, R.shape([seq_len, 20, 128]))
            lv370 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape293), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape294: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv370, R.shape([1, seq_len, 16, 128]))
            reshape295: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape294, R.shape([1, seq_len, 2048]))
            lv371 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight3, model_layers_1_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims295: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv371, axes=None)
            matmul295: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape295, permute_dims295, out_dtype="void")
            add220: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul295, add218)
            rms_norm149: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add220, model_layers_1_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv372 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight3, model_layers_1_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims296: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv372, axes=None)
            matmul296: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm149, permute_dims296, out_dtype="void")
            split73: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul296, indices_or_sections=2, axis=-1)
            split_073: R.Tensor((1, seq_len, 11008), dtype="float16") = split73[0]
            split_173: R.Tensor((1, seq_len, 11008), dtype="float16") = split73[1]
            silu73: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_073)
            mul73: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu73, split_173)
            lv373 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight3, model_layers_1_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims297: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv373, axes=None)
            matmul297: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul73, permute_dims297, out_dtype="void")
            add221: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul297, add220)
            rms_norm150: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add221, model_layers_2_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv374 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight3, model_layers_2_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims298: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv374, axes=None)
            matmul298: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm150, permute_dims298, out_dtype="void")
            add222: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul298, model_layers_2_self_attn_c_attn_bias3)
            reshape296: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add222, R.shape([1, seq_len, 20, 128]))
            reshape297: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape296, R.shape([seq_len, 20, 128]))
            lv375 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape297), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape298: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv375, R.shape([1, seq_len, 16, 128]))
            reshape299: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape298, R.shape([1, seq_len, 2048]))
            lv376 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight3, model_layers_2_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims299: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv376, axes=None)
            matmul299: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape299, permute_dims299, out_dtype="void")
            add223: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul299, add221)
            rms_norm151: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add223, model_layers_2_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv377 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight3, model_layers_2_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims300: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv377, axes=None)
            matmul300: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm151, permute_dims300, out_dtype="void")
            split74: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul300, indices_or_sections=2, axis=-1)
            split_074: R.Tensor((1, seq_len, 11008), dtype="float16") = split74[0]
            split_174: R.Tensor((1, seq_len, 11008), dtype="float16") = split74[1]
            silu74: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_074)
            mul74: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu74, split_174)
            lv378 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight3, model_layers_2_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims301: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv378, axes=None)
            matmul301: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul74, permute_dims301, out_dtype="void")
            add224: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul301, add223)
            rms_norm152: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add224, model_layers_3_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv379 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight3, model_layers_3_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims302: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv379, axes=None)
            matmul302: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm152, permute_dims302, out_dtype="void")
            add225: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul302, model_layers_3_self_attn_c_attn_bias3)
            reshape300: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add225, R.shape([1, seq_len, 20, 128]))
            reshape301: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape300, R.shape([seq_len, 20, 128]))
            lv380 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape301), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape302: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv380, R.shape([1, seq_len, 16, 128]))
            reshape303: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape302, R.shape([1, seq_len, 2048]))
            lv381 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight3, model_layers_3_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims303: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv381, axes=None)
            matmul303: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape303, permute_dims303, out_dtype="void")
            add226: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul303, add224)
            rms_norm153: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add226, model_layers_3_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv382 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight3, model_layers_3_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims304: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv382, axes=None)
            matmul304: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm153, permute_dims304, out_dtype="void")
            split75: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul304, indices_or_sections=2, axis=-1)
            split_075: R.Tensor((1, seq_len, 11008), dtype="float16") = split75[0]
            split_175: R.Tensor((1, seq_len, 11008), dtype="float16") = split75[1]
            silu75: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_075)
            mul75: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu75, split_175)
            lv383 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight3, model_layers_3_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims305: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv383, axes=None)
            matmul305: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul75, permute_dims305, out_dtype="void")
            add227: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul305, add226)
            rms_norm154: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add227, model_layers_4_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv384 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight3, model_layers_4_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims306: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv384, axes=None)
            matmul306: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm154, permute_dims306, out_dtype="void")
            add228: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul306, model_layers_4_self_attn_c_attn_bias3)
            reshape304: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add228, R.shape([1, seq_len, 20, 128]))
            reshape305: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape304, R.shape([seq_len, 20, 128]))
            lv385 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape305), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape306: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv385, R.shape([1, seq_len, 16, 128]))
            reshape307: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape306, R.shape([1, seq_len, 2048]))
            lv386 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight3, model_layers_4_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims307: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv386, axes=None)
            matmul307: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape307, permute_dims307, out_dtype="void")
            add229: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul307, add227)
            rms_norm155: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add229, model_layers_4_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv387 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight3, model_layers_4_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims308: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv387, axes=None)
            matmul308: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm155, permute_dims308, out_dtype="void")
            split76: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul308, indices_or_sections=2, axis=-1)
            split_076: R.Tensor((1, seq_len, 11008), dtype="float16") = split76[0]
            split_176: R.Tensor((1, seq_len, 11008), dtype="float16") = split76[1]
            silu76: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_076)
            mul76: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu76, split_176)
            lv388 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight3, model_layers_4_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims309: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv388, axes=None)
            matmul309: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul76, permute_dims309, out_dtype="void")
            add230: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul309, add229)
            rms_norm156: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add230, model_layers_5_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv389 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight3, model_layers_5_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims310: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv389, axes=None)
            matmul310: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm156, permute_dims310, out_dtype="void")
            add231: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul310, model_layers_5_self_attn_c_attn_bias3)
            reshape308: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add231, R.shape([1, seq_len, 20, 128]))
            reshape309: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape308, R.shape([seq_len, 20, 128]))
            lv390 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape309), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape310: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv390, R.shape([1, seq_len, 16, 128]))
            reshape311: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape310, R.shape([1, seq_len, 2048]))
            lv391 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight3, model_layers_5_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims311: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv391, axes=None)
            matmul311: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape311, permute_dims311, out_dtype="void")
            add232: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul311, add230)
            rms_norm157: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add232, model_layers_5_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv392 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight3, model_layers_5_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims312: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv392, axes=None)
            matmul312: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm157, permute_dims312, out_dtype="void")
            split77: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul312, indices_or_sections=2, axis=-1)
            split_077: R.Tensor((1, seq_len, 11008), dtype="float16") = split77[0]
            split_177: R.Tensor((1, seq_len, 11008), dtype="float16") = split77[1]
            silu77: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_077)
            mul77: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu77, split_177)
            lv393 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight3, model_layers_5_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims313: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv393, axes=None)
            matmul313: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul77, permute_dims313, out_dtype="void")
            add233: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul313, add232)
            rms_norm158: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add233, model_layers_6_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv394 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight3, model_layers_6_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims314: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv394, axes=None)
            matmul314: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm158, permute_dims314, out_dtype="void")
            add234: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul314, model_layers_6_self_attn_c_attn_bias3)
            reshape312: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add234, R.shape([1, seq_len, 20, 128]))
            reshape313: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape312, R.shape([seq_len, 20, 128]))
            lv395 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape313), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape314: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv395, R.shape([1, seq_len, 16, 128]))
            reshape315: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape314, R.shape([1, seq_len, 2048]))
            lv396 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight3, model_layers_6_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims315: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv396, axes=None)
            matmul315: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape315, permute_dims315, out_dtype="void")
            add235: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul315, add233)
            rms_norm159: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add235, model_layers_6_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv397 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight3, model_layers_6_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims316: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv397, axes=None)
            matmul316: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm159, permute_dims316, out_dtype="void")
            split78: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul316, indices_or_sections=2, axis=-1)
            split_078: R.Tensor((1, seq_len, 11008), dtype="float16") = split78[0]
            split_178: R.Tensor((1, seq_len, 11008), dtype="float16") = split78[1]
            silu78: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_078)
            mul78: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu78, split_178)
            lv398 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight3, model_layers_6_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims317: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv398, axes=None)
            matmul317: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul78, permute_dims317, out_dtype="void")
            add236: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul317, add235)
            rms_norm160: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add236, model_layers_7_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv399 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight3, model_layers_7_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims318: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv399, axes=None)
            matmul318: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm160, permute_dims318, out_dtype="void")
            add237: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul318, model_layers_7_self_attn_c_attn_bias3)
            reshape316: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add237, R.shape([1, seq_len, 20, 128]))
            reshape317: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape316, R.shape([seq_len, 20, 128]))
            lv400 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape317), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape318: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv400, R.shape([1, seq_len, 16, 128]))
            reshape319: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape318, R.shape([1, seq_len, 2048]))
            lv401 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight3, model_layers_7_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims319: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv401, axes=None)
            matmul319: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape319, permute_dims319, out_dtype="void")
            add238: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul319, add236)
            rms_norm161: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add238, model_layers_7_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv402 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight3, model_layers_7_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims320: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv402, axes=None)
            matmul320: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm161, permute_dims320, out_dtype="void")
            split79: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul320, indices_or_sections=2, axis=-1)
            split_079: R.Tensor((1, seq_len, 11008), dtype="float16") = split79[0]
            split_179: R.Tensor((1, seq_len, 11008), dtype="float16") = split79[1]
            silu79: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_079)
            mul79: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu79, split_179)
            lv403 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight3, model_layers_7_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims321: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv403, axes=None)
            matmul321: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul79, permute_dims321, out_dtype="void")
            add239: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul321, add238)
            rms_norm162: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add239, model_layers_8_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv404 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight3, model_layers_8_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims322: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv404, axes=None)
            matmul322: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm162, permute_dims322, out_dtype="void")
            add240: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul322, model_layers_8_self_attn_c_attn_bias3)
            reshape320: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add240, R.shape([1, seq_len, 20, 128]))
            reshape321: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape320, R.shape([seq_len, 20, 128]))
            lv405 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape321), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape322: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv405, R.shape([1, seq_len, 16, 128]))
            reshape323: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape322, R.shape([1, seq_len, 2048]))
            lv406 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight3, model_layers_8_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims323: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv406, axes=None)
            matmul323: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape323, permute_dims323, out_dtype="void")
            add241: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul323, add239)
            rms_norm163: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add241, model_layers_8_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv407 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight3, model_layers_8_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims324: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv407, axes=None)
            matmul324: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm163, permute_dims324, out_dtype="void")
            split80: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul324, indices_or_sections=2, axis=-1)
            split_080: R.Tensor((1, seq_len, 11008), dtype="float16") = split80[0]
            split_180: R.Tensor((1, seq_len, 11008), dtype="float16") = split80[1]
            silu80: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_080)
            mul80: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu80, split_180)
            lv408 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight3, model_layers_8_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims325: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv408, axes=None)
            matmul325: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul80, permute_dims325, out_dtype="void")
            add242: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul325, add241)
            rms_norm164: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add242, model_layers_9_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv409 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight3, model_layers_9_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims326: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv409, axes=None)
            matmul326: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm164, permute_dims326, out_dtype="void")
            add243: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul326, model_layers_9_self_attn_c_attn_bias3)
            reshape324: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add243, R.shape([1, seq_len, 20, 128]))
            reshape325: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape324, R.shape([seq_len, 20, 128]))
            lv410 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape325), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape326: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv410, R.shape([1, seq_len, 16, 128]))
            reshape327: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape326, R.shape([1, seq_len, 2048]))
            lv411 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight3, model_layers_9_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims327: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv411, axes=None)
            matmul327: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape327, permute_dims327, out_dtype="void")
            add244: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul327, add242)
            rms_norm165: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add244, model_layers_9_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv412 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight3, model_layers_9_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims328: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv412, axes=None)
            matmul328: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm165, permute_dims328, out_dtype="void")
            split81: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul328, indices_or_sections=2, axis=-1)
            split_081: R.Tensor((1, seq_len, 11008), dtype="float16") = split81[0]
            split_181: R.Tensor((1, seq_len, 11008), dtype="float16") = split81[1]
            silu81: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_081)
            mul81: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu81, split_181)
            lv413 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight3, model_layers_9_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims329: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv413, axes=None)
            matmul329: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul81, permute_dims329, out_dtype="void")
            add245: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul329, add244)
            rms_norm166: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add245, model_layers_10_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv414 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight3, model_layers_10_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims330: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv414, axes=None)
            matmul330: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm166, permute_dims330, out_dtype="void")
            add246: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul330, model_layers_10_self_attn_c_attn_bias3)
            reshape328: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add246, R.shape([1, seq_len, 20, 128]))
            reshape329: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape328, R.shape([seq_len, 20, 128]))
            lv415 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape329), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape330: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv415, R.shape([1, seq_len, 16, 128]))
            reshape331: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape330, R.shape([1, seq_len, 2048]))
            lv416 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight3, model_layers_10_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims331: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv416, axes=None)
            matmul331: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape331, permute_dims331, out_dtype="void")
            add247: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul331, add245)
            rms_norm167: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add247, model_layers_10_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv417 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight3, model_layers_10_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims332: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv417, axes=None)
            matmul332: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm167, permute_dims332, out_dtype="void")
            split82: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul332, indices_or_sections=2, axis=-1)
            split_082: R.Tensor((1, seq_len, 11008), dtype="float16") = split82[0]
            split_182: R.Tensor((1, seq_len, 11008), dtype="float16") = split82[1]
            silu82: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_082)
            mul82: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu82, split_182)
            lv418 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight3, model_layers_10_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims333: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv418, axes=None)
            matmul333: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul82, permute_dims333, out_dtype="void")
            add248: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul333, add247)
            rms_norm168: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add248, model_layers_11_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv419 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight3, model_layers_11_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims334: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv419, axes=None)
            matmul334: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm168, permute_dims334, out_dtype="void")
            add249: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul334, model_layers_11_self_attn_c_attn_bias3)
            reshape332: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add249, R.shape([1, seq_len, 20, 128]))
            reshape333: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape332, R.shape([seq_len, 20, 128]))
            lv420 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape333), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape334: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv420, R.shape([1, seq_len, 16, 128]))
            reshape335: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape334, R.shape([1, seq_len, 2048]))
            lv421 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight3, model_layers_11_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims335: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv421, axes=None)
            matmul335: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape335, permute_dims335, out_dtype="void")
            add250: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul335, add248)
            rms_norm169: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add250, model_layers_11_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv422 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight3, model_layers_11_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims336: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv422, axes=None)
            matmul336: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm169, permute_dims336, out_dtype="void")
            split83: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul336, indices_or_sections=2, axis=-1)
            split_083: R.Tensor((1, seq_len, 11008), dtype="float16") = split83[0]
            split_183: R.Tensor((1, seq_len, 11008), dtype="float16") = split83[1]
            silu83: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_083)
            mul83: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu83, split_183)
            lv423 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight3, model_layers_11_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims337: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv423, axes=None)
            matmul337: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul83, permute_dims337, out_dtype="void")
            add251: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul337, add250)
            rms_norm170: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add251, model_layers_12_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv424 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight3, model_layers_12_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims338: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv424, axes=None)
            matmul338: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm170, permute_dims338, out_dtype="void")
            add252: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul338, model_layers_12_self_attn_c_attn_bias3)
            reshape336: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add252, R.shape([1, seq_len, 20, 128]))
            reshape337: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape336, R.shape([seq_len, 20, 128]))
            lv425 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape337), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape338: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv425, R.shape([1, seq_len, 16, 128]))
            reshape339: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape338, R.shape([1, seq_len, 2048]))
            lv426 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight3, model_layers_12_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims339: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv426, axes=None)
            matmul339: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape339, permute_dims339, out_dtype="void")
            add253: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul339, add251)
            rms_norm171: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add253, model_layers_12_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv427 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight3, model_layers_12_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims340: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv427, axes=None)
            matmul340: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm171, permute_dims340, out_dtype="void")
            split84: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul340, indices_or_sections=2, axis=-1)
            split_084: R.Tensor((1, seq_len, 11008), dtype="float16") = split84[0]
            split_184: R.Tensor((1, seq_len, 11008), dtype="float16") = split84[1]
            silu84: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_084)
            mul84: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu84, split_184)
            lv428 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight3, model_layers_12_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims341: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv428, axes=None)
            matmul341: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul84, permute_dims341, out_dtype="void")
            add254: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul341, add253)
            rms_norm172: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add254, model_layers_13_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv429 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight3, model_layers_13_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims342: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv429, axes=None)
            matmul342: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm172, permute_dims342, out_dtype="void")
            add255: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul342, model_layers_13_self_attn_c_attn_bias3)
            reshape340: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add255, R.shape([1, seq_len, 20, 128]))
            reshape341: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape340, R.shape([seq_len, 20, 128]))
            lv430 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape341), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape342: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv430, R.shape([1, seq_len, 16, 128]))
            reshape343: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape342, R.shape([1, seq_len, 2048]))
            lv431 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight3, model_layers_13_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims343: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv431, axes=None)
            matmul343: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape343, permute_dims343, out_dtype="void")
            add256: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul343, add254)
            rms_norm173: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add256, model_layers_13_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv432 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight3, model_layers_13_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims344: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv432, axes=None)
            matmul344: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm173, permute_dims344, out_dtype="void")
            split85: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul344, indices_or_sections=2, axis=-1)
            split_085: R.Tensor((1, seq_len, 11008), dtype="float16") = split85[0]
            split_185: R.Tensor((1, seq_len, 11008), dtype="float16") = split85[1]
            silu85: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_085)
            mul85: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu85, split_185)
            lv433 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight3, model_layers_13_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims345: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv433, axes=None)
            matmul345: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul85, permute_dims345, out_dtype="void")
            add257: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul345, add256)
            rms_norm174: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add257, model_layers_14_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv434 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight3, model_layers_14_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims346: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv434, axes=None)
            matmul346: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm174, permute_dims346, out_dtype="void")
            add258: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul346, model_layers_14_self_attn_c_attn_bias3)
            reshape344: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add258, R.shape([1, seq_len, 20, 128]))
            reshape345: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape344, R.shape([seq_len, 20, 128]))
            lv435 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape345), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape346: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv435, R.shape([1, seq_len, 16, 128]))
            reshape347: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape346, R.shape([1, seq_len, 2048]))
            lv436 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight3, model_layers_14_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims347: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv436, axes=None)
            matmul347: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape347, permute_dims347, out_dtype="void")
            add259: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul347, add257)
            rms_norm175: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add259, model_layers_14_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv437 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight3, model_layers_14_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims348: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv437, axes=None)
            matmul348: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm175, permute_dims348, out_dtype="void")
            split86: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul348, indices_or_sections=2, axis=-1)
            split_086: R.Tensor((1, seq_len, 11008), dtype="float16") = split86[0]
            split_186: R.Tensor((1, seq_len, 11008), dtype="float16") = split86[1]
            silu86: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_086)
            mul86: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu86, split_186)
            lv438 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight3, model_layers_14_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims349: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv438, axes=None)
            matmul349: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul86, permute_dims349, out_dtype="void")
            add260: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul349, add259)
            rms_norm176: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add260, model_layers_15_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv439 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight3, model_layers_15_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims350: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv439, axes=None)
            matmul350: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm176, permute_dims350, out_dtype="void")
            add261: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul350, model_layers_15_self_attn_c_attn_bias3)
            reshape348: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add261, R.shape([1, seq_len, 20, 128]))
            reshape349: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape348, R.shape([seq_len, 20, 128]))
            lv440 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape349), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape350: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv440, R.shape([1, seq_len, 16, 128]))
            reshape351: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape350, R.shape([1, seq_len, 2048]))
            lv441 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight3, model_layers_15_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims351: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv441, axes=None)
            matmul351: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape351, permute_dims351, out_dtype="void")
            add262: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul351, add260)
            rms_norm177: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add262, model_layers_15_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv442 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight3, model_layers_15_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims352: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv442, axes=None)
            matmul352: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm177, permute_dims352, out_dtype="void")
            split87: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul352, indices_or_sections=2, axis=-1)
            split_087: R.Tensor((1, seq_len, 11008), dtype="float16") = split87[0]
            split_187: R.Tensor((1, seq_len, 11008), dtype="float16") = split87[1]
            silu87: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_087)
            mul87: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu87, split_187)
            lv443 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight3, model_layers_15_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims353: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv443, axes=None)
            matmul353: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul87, permute_dims353, out_dtype="void")
            add263: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul353, add262)
            rms_norm178: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add263, model_layers_16_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv444 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight3, model_layers_16_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims354: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv444, axes=None)
            matmul354: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm178, permute_dims354, out_dtype="void")
            add264: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul354, model_layers_16_self_attn_c_attn_bias3)
            reshape352: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add264, R.shape([1, seq_len, 20, 128]))
            reshape353: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape352, R.shape([seq_len, 20, 128]))
            lv445 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape353), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape354: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv445, R.shape([1, seq_len, 16, 128]))
            reshape355: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape354, R.shape([1, seq_len, 2048]))
            lv446 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight3, model_layers_16_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims355: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv446, axes=None)
            matmul355: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape355, permute_dims355, out_dtype="void")
            add265: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul355, add263)
            rms_norm179: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add265, model_layers_16_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv447 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight3, model_layers_16_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims356: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv447, axes=None)
            matmul356: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm179, permute_dims356, out_dtype="void")
            split88: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul356, indices_or_sections=2, axis=-1)
            split_088: R.Tensor((1, seq_len, 11008), dtype="float16") = split88[0]
            split_188: R.Tensor((1, seq_len, 11008), dtype="float16") = split88[1]
            silu88: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_088)
            mul88: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu88, split_188)
            lv448 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight3, model_layers_16_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims357: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv448, axes=None)
            matmul357: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul88, permute_dims357, out_dtype="void")
            add266: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul357, add265)
            rms_norm180: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add266, model_layers_17_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv449 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight3, model_layers_17_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims358: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv449, axes=None)
            matmul358: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm180, permute_dims358, out_dtype="void")
            add267: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul358, model_layers_17_self_attn_c_attn_bias3)
            reshape356: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add267, R.shape([1, seq_len, 20, 128]))
            reshape357: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape356, R.shape([seq_len, 20, 128]))
            lv450 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape357), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape358: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv450, R.shape([1, seq_len, 16, 128]))
            reshape359: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape358, R.shape([1, seq_len, 2048]))
            lv451 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight3, model_layers_17_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims359: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv451, axes=None)
            matmul359: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape359, permute_dims359, out_dtype="void")
            add268: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul359, add266)
            rms_norm181: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add268, model_layers_17_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv452 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight3, model_layers_17_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims360: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv452, axes=None)
            matmul360: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm181, permute_dims360, out_dtype="void")
            split89: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul360, indices_or_sections=2, axis=-1)
            split_089: R.Tensor((1, seq_len, 11008), dtype="float16") = split89[0]
            split_189: R.Tensor((1, seq_len, 11008), dtype="float16") = split89[1]
            silu89: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_089)
            mul89: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu89, split_189)
            lv453 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight3, model_layers_17_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims361: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv453, axes=None)
            matmul361: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul89, permute_dims361, out_dtype="void")
            add269: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul361, add268)
            rms_norm182: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add269, model_layers_18_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv454 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight3, model_layers_18_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims362: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv454, axes=None)
            matmul362: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm182, permute_dims362, out_dtype="void")
            add270: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul362, model_layers_18_self_attn_c_attn_bias3)
            reshape360: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add270, R.shape([1, seq_len, 20, 128]))
            reshape361: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape360, R.shape([seq_len, 20, 128]))
            lv455 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape361), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape362: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv455, R.shape([1, seq_len, 16, 128]))
            reshape363: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape362, R.shape([1, seq_len, 2048]))
            lv456 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight3, model_layers_18_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims363: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv456, axes=None)
            matmul363: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape363, permute_dims363, out_dtype="void")
            add271: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul363, add269)
            rms_norm183: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add271, model_layers_18_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv457 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight3, model_layers_18_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims364: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv457, axes=None)
            matmul364: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm183, permute_dims364, out_dtype="void")
            split90: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul364, indices_or_sections=2, axis=-1)
            split_090: R.Tensor((1, seq_len, 11008), dtype="float16") = split90[0]
            split_190: R.Tensor((1, seq_len, 11008), dtype="float16") = split90[1]
            silu90: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_090)
            mul90: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu90, split_190)
            lv458 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight3, model_layers_18_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims365: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv458, axes=None)
            matmul365: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul90, permute_dims365, out_dtype="void")
            add272: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul365, add271)
            rms_norm184: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add272, model_layers_19_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv459 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight3, model_layers_19_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims366: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv459, axes=None)
            matmul366: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm184, permute_dims366, out_dtype="void")
            add273: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul366, model_layers_19_self_attn_c_attn_bias3)
            reshape364: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add273, R.shape([1, seq_len, 20, 128]))
            reshape365: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape364, R.shape([seq_len, 20, 128]))
            lv460 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape365), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape366: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv460, R.shape([1, seq_len, 16, 128]))
            reshape367: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape366, R.shape([1, seq_len, 2048]))
            lv461 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight3, model_layers_19_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims367: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv461, axes=None)
            matmul367: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape367, permute_dims367, out_dtype="void")
            add274: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul367, add272)
            rms_norm185: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add274, model_layers_19_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv462 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight3, model_layers_19_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims368: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv462, axes=None)
            matmul368: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm185, permute_dims368, out_dtype="void")
            split91: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul368, indices_or_sections=2, axis=-1)
            split_091: R.Tensor((1, seq_len, 11008), dtype="float16") = split91[0]
            split_191: R.Tensor((1, seq_len, 11008), dtype="float16") = split91[1]
            silu91: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_091)
            mul91: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu91, split_191)
            lv463 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight3, model_layers_19_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims369: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv463, axes=None)
            matmul369: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul91, permute_dims369, out_dtype="void")
            add275: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul369, add274)
            rms_norm186: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add275, model_layers_20_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv464 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight3, model_layers_20_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims370: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv464, axes=None)
            matmul370: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm186, permute_dims370, out_dtype="void")
            add276: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul370, model_layers_20_self_attn_c_attn_bias3)
            reshape368: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add276, R.shape([1, seq_len, 20, 128]))
            reshape369: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape368, R.shape([seq_len, 20, 128]))
            lv465 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape369), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape370: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv465, R.shape([1, seq_len, 16, 128]))
            reshape371: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape370, R.shape([1, seq_len, 2048]))
            lv466 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight3, model_layers_20_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims371: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv466, axes=None)
            matmul371: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape371, permute_dims371, out_dtype="void")
            add277: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul371, add275)
            rms_norm187: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add277, model_layers_20_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv467 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight3, model_layers_20_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims372: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv467, axes=None)
            matmul372: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm187, permute_dims372, out_dtype="void")
            split92: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul372, indices_or_sections=2, axis=-1)
            split_092: R.Tensor((1, seq_len, 11008), dtype="float16") = split92[0]
            split_192: R.Tensor((1, seq_len, 11008), dtype="float16") = split92[1]
            silu92: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_092)
            mul92: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu92, split_192)
            lv468 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight3, model_layers_20_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims373: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv468, axes=None)
            matmul373: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul92, permute_dims373, out_dtype="void")
            add278: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul373, add277)
            rms_norm188: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add278, model_layers_21_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv469 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight3, model_layers_21_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims374: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv469, axes=None)
            matmul374: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm188, permute_dims374, out_dtype="void")
            add279: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul374, model_layers_21_self_attn_c_attn_bias3)
            reshape372: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add279, R.shape([1, seq_len, 20, 128]))
            reshape373: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape372, R.shape([seq_len, 20, 128]))
            lv470 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape373), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape374: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv470, R.shape([1, seq_len, 16, 128]))
            reshape375: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape374, R.shape([1, seq_len, 2048]))
            lv471 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight3, model_layers_21_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims375: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv471, axes=None)
            matmul375: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape375, permute_dims375, out_dtype="void")
            add280: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul375, add278)
            rms_norm189: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add280, model_layers_21_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv472 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight3, model_layers_21_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims376: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv472, axes=None)
            matmul376: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm189, permute_dims376, out_dtype="void")
            split93: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul376, indices_or_sections=2, axis=-1)
            split_093: R.Tensor((1, seq_len, 11008), dtype="float16") = split93[0]
            split_193: R.Tensor((1, seq_len, 11008), dtype="float16") = split93[1]
            silu93: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_093)
            mul93: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu93, split_193)
            lv473 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight3, model_layers_21_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims377: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv473, axes=None)
            matmul377: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul93, permute_dims377, out_dtype="void")
            add281: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul377, add280)
            rms_norm190: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add281, model_layers_22_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv474 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight3, model_layers_22_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims378: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv474, axes=None)
            matmul378: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm190, permute_dims378, out_dtype="void")
            add282: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul378, model_layers_22_self_attn_c_attn_bias3)
            reshape376: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add282, R.shape([1, seq_len, 20, 128]))
            reshape377: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape376, R.shape([seq_len, 20, 128]))
            lv475 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape377), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape378: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv475, R.shape([1, seq_len, 16, 128]))
            reshape379: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape378, R.shape([1, seq_len, 2048]))
            lv476 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight3, model_layers_22_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims379: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv476, axes=None)
            matmul379: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape379, permute_dims379, out_dtype="void")
            add283: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul379, add281)
            rms_norm191: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add283, model_layers_22_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv477 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight3, model_layers_22_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims380: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv477, axes=None)
            matmul380: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm191, permute_dims380, out_dtype="void")
            split94: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul380, indices_or_sections=2, axis=-1)
            split_094: R.Tensor((1, seq_len, 11008), dtype="float16") = split94[0]
            split_194: R.Tensor((1, seq_len, 11008), dtype="float16") = split94[1]
            silu94: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_094)
            mul94: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu94, split_194)
            lv478 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight3, model_layers_22_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims381: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv478, axes=None)
            matmul381: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul94, permute_dims381, out_dtype="void")
            add284: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul381, add283)
            rms_norm192: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add284, model_layers_23_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv479 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight3, model_layers_23_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims382: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv479, axes=None)
            matmul382: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm192, permute_dims382, out_dtype="void")
            add285: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul382, model_layers_23_self_attn_c_attn_bias3)
            reshape380: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add285, R.shape([1, seq_len, 20, 128]))
            reshape381: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape380, R.shape([seq_len, 20, 128]))
            lv480 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape381), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape382: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv480, R.shape([1, seq_len, 16, 128]))
            reshape383: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape382, R.shape([1, seq_len, 2048]))
            lv481 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight3, model_layers_23_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims383: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv481, axes=None)
            matmul383: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape383, permute_dims383, out_dtype="void")
            add286: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul383, add284)
            rms_norm193: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add286, model_layers_23_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv482 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight3, model_layers_23_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims384: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv482, axes=None)
            matmul384: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm193, permute_dims384, out_dtype="void")
            split95: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul384, indices_or_sections=2, axis=-1)
            split_095: R.Tensor((1, seq_len, 11008), dtype="float16") = split95[0]
            split_195: R.Tensor((1, seq_len, 11008), dtype="float16") = split95[1]
            silu95: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_095)
            mul95: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu95, split_195)
            lv483 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight3, model_layers_23_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims385: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv483, axes=None)
            matmul385: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul95, permute_dims385, out_dtype="void")
            add287: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul385, add286)
            rms_norm194: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add287, model_layers_24_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv484 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight3, model_layers_24_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims386: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv484, axes=None)
            matmul386: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm194, permute_dims386, out_dtype="void")
            add288: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul386, model_layers_24_self_attn_c_attn_bias3)
            reshape384: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add288, R.shape([1, seq_len, 20, 128]))
            reshape385: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape384, R.shape([seq_len, 20, 128]))
            lv485 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape385), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape386: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv485, R.shape([1, seq_len, 16, 128]))
            reshape387: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape386, R.shape([1, seq_len, 2048]))
            lv486 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight3, model_layers_24_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims387: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv486, axes=None)
            matmul387: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape387, permute_dims387, out_dtype="void")
            add289: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul387, add287)
            rms_norm195: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add289, model_layers_24_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv487 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight3, model_layers_24_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims388: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv487, axes=None)
            matmul388: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm195, permute_dims388, out_dtype="void")
            split96: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul388, indices_or_sections=2, axis=-1)
            split_096: R.Tensor((1, seq_len, 11008), dtype="float16") = split96[0]
            split_196: R.Tensor((1, seq_len, 11008), dtype="float16") = split96[1]
            silu96: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_096)
            mul96: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu96, split_196)
            lv488 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight3, model_layers_24_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims389: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv488, axes=None)
            matmul389: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul96, permute_dims389, out_dtype="void")
            add290: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul389, add289)
            rms_norm196: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add290, model_layers_25_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv489 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight3, model_layers_25_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims390: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv489, axes=None)
            matmul390: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm196, permute_dims390, out_dtype="void")
            add291: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul390, model_layers_25_self_attn_c_attn_bias3)
            reshape388: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add291, R.shape([1, seq_len, 20, 128]))
            reshape389: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape388, R.shape([seq_len, 20, 128]))
            lv490 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape389), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape390: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv490, R.shape([1, seq_len, 16, 128]))
            reshape391: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape390, R.shape([1, seq_len, 2048]))
            lv491 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight3, model_layers_25_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims391: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv491, axes=None)
            matmul391: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape391, permute_dims391, out_dtype="void")
            add292: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul391, add290)
            rms_norm197: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add292, model_layers_25_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv492 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight3, model_layers_25_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims392: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv492, axes=None)
            matmul392: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm197, permute_dims392, out_dtype="void")
            split97: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul392, indices_or_sections=2, axis=-1)
            split_097: R.Tensor((1, seq_len, 11008), dtype="float16") = split97[0]
            split_197: R.Tensor((1, seq_len, 11008), dtype="float16") = split97[1]
            silu97: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_097)
            mul97: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu97, split_197)
            lv493 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight3, model_layers_25_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims393: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv493, axes=None)
            matmul393: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul97, permute_dims393, out_dtype="void")
            add293: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul393, add292)
            rms_norm198: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add293, model_layers_26_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv494 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight3, model_layers_26_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims394: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv494, axes=None)
            matmul394: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm198, permute_dims394, out_dtype="void")
            add294: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul394, model_layers_26_self_attn_c_attn_bias3)
            reshape392: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add294, R.shape([1, seq_len, 20, 128]))
            reshape393: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape392, R.shape([seq_len, 20, 128]))
            lv495 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape393), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape394: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv495, R.shape([1, seq_len, 16, 128]))
            reshape395: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape394, R.shape([1, seq_len, 2048]))
            lv496 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight3, model_layers_26_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims395: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv496, axes=None)
            matmul395: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape395, permute_dims395, out_dtype="void")
            add295: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul395, add293)
            rms_norm199: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add295, model_layers_26_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv497 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight3, model_layers_26_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims396: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv497, axes=None)
            matmul396: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm199, permute_dims396, out_dtype="void")
            split98: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul396, indices_or_sections=2, axis=-1)
            split_098: R.Tensor((1, seq_len, 11008), dtype="float16") = split98[0]
            split_198: R.Tensor((1, seq_len, 11008), dtype="float16") = split98[1]
            silu98: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_098)
            mul98: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu98, split_198)
            lv498 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight3, model_layers_26_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims397: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv498, axes=None)
            matmul397: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul98, permute_dims397, out_dtype="void")
            add296: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul397, add295)
            rms_norm200: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add296, model_layers_27_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv499 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight3, model_layers_27_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims398: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv499, axes=None)
            matmul398: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm200, permute_dims398, out_dtype="void")
            add297: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul398, model_layers_27_self_attn_c_attn_bias3)
            reshape396: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add297, R.shape([1, seq_len, 20, 128]))
            reshape397: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape396, R.shape([seq_len, 20, 128]))
            lv500 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape397), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape398: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv500, R.shape([1, seq_len, 16, 128]))
            reshape399: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape398, R.shape([1, seq_len, 2048]))
            lv501 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight3, model_layers_27_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims399: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv501, axes=None)
            matmul399: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape399, permute_dims399, out_dtype="void")
            add298: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul399, add296)
            rms_norm201: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add298, model_layers_27_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv502 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight3, model_layers_27_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims400: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv502, axes=None)
            matmul400: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm201, permute_dims400, out_dtype="void")
            split99: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul400, indices_or_sections=2, axis=-1)
            split_099: R.Tensor((1, seq_len, 11008), dtype="float16") = split99[0]
            split_199: R.Tensor((1, seq_len, 11008), dtype="float16") = split99[1]
            silu99: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_099)
            mul99: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu99, split_199)
            lv503 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight3, model_layers_27_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims401: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv503, axes=None)
            matmul401: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul99, permute_dims401, out_dtype="void")
            add299: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul401, add298)
            rms_norm202: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add299, model_layers_28_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv504 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight3, model_layers_28_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims402: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv504, axes=None)
            matmul402: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm202, permute_dims402, out_dtype="void")
            add300: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul402, model_layers_28_self_attn_c_attn_bias3)
            reshape400: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add300, R.shape([1, seq_len, 20, 128]))
            reshape401: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape400, R.shape([seq_len, 20, 128]))
            lv505 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape401), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape402: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv505, R.shape([1, seq_len, 16, 128]))
            reshape403: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape402, R.shape([1, seq_len, 2048]))
            lv506 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight3, model_layers_28_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims403: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv506, axes=None)
            matmul403: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape403, permute_dims403, out_dtype="void")
            add301: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul403, add299)
            rms_norm203: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add301, model_layers_28_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv507 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight3, model_layers_28_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims404: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv507, axes=None)
            matmul404: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm203, permute_dims404, out_dtype="void")
            split100: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul404, indices_or_sections=2, axis=-1)
            split_0100: R.Tensor((1, seq_len, 11008), dtype="float16") = split100[0]
            split_1100: R.Tensor((1, seq_len, 11008), dtype="float16") = split100[1]
            silu100: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0100)
            mul100: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu100, split_1100)
            lv508 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight3, model_layers_28_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims405: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv508, axes=None)
            matmul405: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul100, permute_dims405, out_dtype="void")
            add302: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul405, add301)
            rms_norm204: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add302, model_layers_29_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv509 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight3, model_layers_29_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims406: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv509, axes=None)
            matmul406: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm204, permute_dims406, out_dtype="void")
            add303: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul406, model_layers_29_self_attn_c_attn_bias3)
            reshape404: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add303, R.shape([1, seq_len, 20, 128]))
            reshape405: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape404, R.shape([seq_len, 20, 128]))
            lv510 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape405), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape406: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv510, R.shape([1, seq_len, 16, 128]))
            reshape407: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape406, R.shape([1, seq_len, 2048]))
            lv511 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight3, model_layers_29_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims407: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv511, axes=None)
            matmul407: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape407, permute_dims407, out_dtype="void")
            add304: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul407, add302)
            rms_norm205: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add304, model_layers_29_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv512 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight3, model_layers_29_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims408: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv512, axes=None)
            matmul408: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm205, permute_dims408, out_dtype="void")
            split101: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul408, indices_or_sections=2, axis=-1)
            split_0101: R.Tensor((1, seq_len, 11008), dtype="float16") = split101[0]
            split_1101: R.Tensor((1, seq_len, 11008), dtype="float16") = split101[1]
            silu101: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0101)
            mul101: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu101, split_1101)
            lv513 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight3, model_layers_29_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims409: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv513, axes=None)
            matmul409: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul101, permute_dims409, out_dtype="void")
            add305: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul409, add304)
            rms_norm206: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add305, model_layers_30_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv514 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight3, model_layers_30_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims410: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv514, axes=None)
            matmul410: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm206, permute_dims410, out_dtype="void")
            add306: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul410, model_layers_30_self_attn_c_attn_bias3)
            reshape408: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add306, R.shape([1, seq_len, 20, 128]))
            reshape409: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape408, R.shape([seq_len, 20, 128]))
            lv515 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape409), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape410: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv515, R.shape([1, seq_len, 16, 128]))
            reshape411: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape410, R.shape([1, seq_len, 2048]))
            lv516 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight3, model_layers_30_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims411: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv516, axes=None)
            matmul411: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape411, permute_dims411, out_dtype="void")
            add307: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul411, add305)
            rms_norm207: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add307, model_layers_30_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv517 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight3, model_layers_30_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims412: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv517, axes=None)
            matmul412: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm207, permute_dims412, out_dtype="void")
            split102: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul412, indices_or_sections=2, axis=-1)
            split_0102: R.Tensor((1, seq_len, 11008), dtype="float16") = split102[0]
            split_1102: R.Tensor((1, seq_len, 11008), dtype="float16") = split102[1]
            silu102: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0102)
            mul102: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu102, split_1102)
            lv518 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight3, model_layers_30_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims413: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv518, axes=None)
            matmul413: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul102, permute_dims413, out_dtype="void")
            add308: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul413, add307)
            rms_norm208: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add308, model_layers_31_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv519 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight3, model_layers_31_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims414: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv519, axes=None)
            matmul414: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm208, permute_dims414, out_dtype="void")
            add309: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul414, model_layers_31_self_attn_c_attn_bias3)
            reshape412: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add309, R.shape([1, seq_len, 20, 128]))
            reshape413: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape412, R.shape([seq_len, 20, 128]))
            lv520 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape413), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape414: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv520, R.shape([1, seq_len, 16, 128]))
            reshape415: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape414, R.shape([1, seq_len, 2048]))
            lv521 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight3, model_layers_31_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims415: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv521, axes=None)
            matmul415: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape415, permute_dims415, out_dtype="void")
            add310: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul415, add308)
            rms_norm209: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add310, model_layers_31_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv522 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight3, model_layers_31_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims416: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv522, axes=None)
            matmul416: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm209, permute_dims416, out_dtype="void")
            split103: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul416, indices_or_sections=2, axis=-1)
            split_0103: R.Tensor((1, seq_len, 11008), dtype="float16") = split103[0]
            split_1103: R.Tensor((1, seq_len, 11008), dtype="float16") = split103[1]
            silu103: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0103)
            mul103: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu103, split_1103)
            lv523 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight3, model_layers_31_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims417: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv523, axes=None)
            matmul417: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul103, permute_dims417, out_dtype="void")
            add311: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul417, add310)
            rms_norm210: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add311, model_layers_32_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv524 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight3, model_layers_32_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims418: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv524, axes=None)
            matmul418: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm210, permute_dims418, out_dtype="void")
            add312: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul418, model_layers_32_self_attn_c_attn_bias3)
            reshape416: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add312, R.shape([1, seq_len, 20, 128]))
            reshape417: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape416, R.shape([seq_len, 20, 128]))
            lv525 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape417), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape418: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv525, R.shape([1, seq_len, 16, 128]))
            reshape419: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape418, R.shape([1, seq_len, 2048]))
            lv526 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight3, model_layers_32_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims419: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv526, axes=None)
            matmul419: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape419, permute_dims419, out_dtype="void")
            add313: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul419, add311)
            rms_norm211: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add313, model_layers_32_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv527 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight3, model_layers_32_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims420: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv527, axes=None)
            matmul420: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm211, permute_dims420, out_dtype="void")
            split104: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul420, indices_or_sections=2, axis=-1)
            split_0104: R.Tensor((1, seq_len, 11008), dtype="float16") = split104[0]
            split_1104: R.Tensor((1, seq_len, 11008), dtype="float16") = split104[1]
            silu104: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0104)
            mul104: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu104, split_1104)
            lv528 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight3, model_layers_32_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims421: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv528, axes=None)
            matmul421: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul104, permute_dims421, out_dtype="void")
            add314: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul421, add313)
            rms_norm212: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add314, model_layers_33_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv529 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight3, model_layers_33_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims422: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv529, axes=None)
            matmul422: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm212, permute_dims422, out_dtype="void")
            add315: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul422, model_layers_33_self_attn_c_attn_bias3)
            reshape420: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add315, R.shape([1, seq_len, 20, 128]))
            reshape421: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape420, R.shape([seq_len, 20, 128]))
            lv530 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape421), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape422: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv530, R.shape([1, seq_len, 16, 128]))
            reshape423: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape422, R.shape([1, seq_len, 2048]))
            lv531 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight3, model_layers_33_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims423: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv531, axes=None)
            matmul423: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape423, permute_dims423, out_dtype="void")
            add316: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul423, add314)
            rms_norm213: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add316, model_layers_33_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv532 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight3, model_layers_33_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims424: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv532, axes=None)
            matmul424: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm213, permute_dims424, out_dtype="void")
            split105: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul424, indices_or_sections=2, axis=-1)
            split_0105: R.Tensor((1, seq_len, 11008), dtype="float16") = split105[0]
            split_1105: R.Tensor((1, seq_len, 11008), dtype="float16") = split105[1]
            silu105: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0105)
            mul105: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu105, split_1105)
            lv533 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight3, model_layers_33_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims425: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv533, axes=None)
            matmul425: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul105, permute_dims425, out_dtype="void")
            add317: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul425, add316)
            rms_norm214: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add317, model_layers_34_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv534 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight3, model_layers_34_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims426: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv534, axes=None)
            matmul426: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm214, permute_dims426, out_dtype="void")
            add318: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul426, model_layers_34_self_attn_c_attn_bias3)
            reshape424: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add318, R.shape([1, seq_len, 20, 128]))
            reshape425: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape424, R.shape([seq_len, 20, 128]))
            lv535 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape425), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape426: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv535, R.shape([1, seq_len, 16, 128]))
            reshape427: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape426, R.shape([1, seq_len, 2048]))
            lv536 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight3, model_layers_34_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims427: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv536, axes=None)
            matmul427: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape427, permute_dims427, out_dtype="void")
            add319: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul427, add317)
            rms_norm215: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add319, model_layers_34_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv537 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight3, model_layers_34_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims428: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv537, axes=None)
            matmul428: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm215, permute_dims428, out_dtype="void")
            split106: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul428, indices_or_sections=2, axis=-1)
            split_0106: R.Tensor((1, seq_len, 11008), dtype="float16") = split106[0]
            split_1106: R.Tensor((1, seq_len, 11008), dtype="float16") = split106[1]
            silu106: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0106)
            mul106: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu106, split_1106)
            lv538 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight3, model_layers_34_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims429: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv538, axes=None)
            matmul429: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul106, permute_dims429, out_dtype="void")
            add320: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul429, add319)
            rms_norm216: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add320, model_layers_35_input_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv539 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight3, model_layers_35_self_attn_c_attn_q_scale3), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims430: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv539, axes=None)
            matmul430: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm216, permute_dims430, out_dtype="void")
            add321: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul430, model_layers_35_self_attn_c_attn_bias3)
            reshape428: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add321, R.shape([1, seq_len, 20, 128]))
            reshape429: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape428, R.shape([seq_len, 20, 128]))
            lv540 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape429), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape430: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv540, R.shape([1, seq_len, 16, 128]))
            reshape431: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape430, R.shape([1, seq_len, 2048]))
            lv541 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight3, model_layers_35_self_attn_o_proj_q_scale3), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims431: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv541, axes=None)
            matmul431: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape431, permute_dims431, out_dtype="void")
            add322: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul431, add320)
            rms_norm217: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add322, model_layers_35_post_attention_layernorm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            lv542 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight3, model_layers_35_mlp_gate_up_proj_q_scale3), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims432: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv542, axes=None)
            matmul432: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm217, permute_dims432, out_dtype="void")
            split107: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul432, indices_or_sections=2, axis=-1)
            split_0107: R.Tensor((1, seq_len, 11008), dtype="float16") = split107[0]
            split_1107: R.Tensor((1, seq_len, 11008), dtype="float16") = split107[1]
            silu107: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0107)
            mul107: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu107, split_1107)
            lv543 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight3, model_layers_35_mlp_down_proj_q_scale3), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims433: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv543, axes=None)
            matmul433: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul107, permute_dims433, out_dtype="void")
            add323: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul433, add322)
            rms_norm218: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add323, model_norm_weight3, axes=[-1], epsilon=9.9999999999999995e-07)
            take1: R.Tensor((1, batch_size, 2048), dtype="float16") = R.take(rms_norm218, logit_positions, axis=1)
            lv544 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight3, model_embed_tokens_q_scale3), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            permute_dims434: R.Tensor((2048, 151936), dtype="float16") = R.permute_dims(lv544, axes=None)
            matmul434: R.Tensor((1, batch_size, 151936), dtype="float32") = R.matmul(take1, permute_dims434, out_dtype="float32")
            gv3: R.Tuple(R.Tensor((1, batch_size, 151936), dtype="float32"), R.Object) = matmul434, paged_kv_cache
            R.output(gv3)
        return gv3

    @R.function
    def batch_verify(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, "seq_len", 151936), dtype="float32"), R.Object):
        seq_len = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size", "seq_len"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight5: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale5: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm292: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(input_embeds, model_layers_0_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv726 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight5, model_layers_0_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims580: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv726, axes=None)
            matmul580: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm292, permute_dims580, out_dtype="void")
            add432: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul580, model_layers_0_self_attn_c_attn_bias5)
            reshape576: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add432, R.shape([1, seq_len, 20, 128]))
            reshape577: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape576, R.shape([seq_len, 20, 128]))
            lv727 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape577), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape578: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv727, R.shape([1, seq_len, 16, 128]))
            reshape579: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape578, R.shape([1, seq_len, 2048]))
            lv728 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight5, model_layers_0_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims581: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv728, axes=None)
            matmul581: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape579, permute_dims581, out_dtype="void")
            add433: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul581, input_embeds)
            rms_norm293: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add433, model_layers_0_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv729 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight5, model_layers_0_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims582: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv729, axes=None)
            matmul582: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm293, permute_dims582, out_dtype="void")
            split144: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul582, indices_or_sections=2, axis=-1)
            split_0144: R.Tensor((1, seq_len, 11008), dtype="float16") = split144[0]
            split_1144: R.Tensor((1, seq_len, 11008), dtype="float16") = split144[1]
            silu144: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0144)
            mul144: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu144, split_1144)
            lv730 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight5, model_layers_0_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims583: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv730, axes=None)
            matmul583: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul144, permute_dims583, out_dtype="void")
            add434: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul583, add433)
            rms_norm294: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add434, model_layers_1_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv731 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight5, model_layers_1_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims584: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv731, axes=None)
            matmul584: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm294, permute_dims584, out_dtype="void")
            add435: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul584, model_layers_1_self_attn_c_attn_bias5)
            reshape580: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add435, R.shape([1, seq_len, 20, 128]))
            reshape581: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape580, R.shape([seq_len, 20, 128]))
            lv732 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape581), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape582: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv732, R.shape([1, seq_len, 16, 128]))
            reshape583: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape582, R.shape([1, seq_len, 2048]))
            lv733 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight5, model_layers_1_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims585: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv733, axes=None)
            matmul585: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape583, permute_dims585, out_dtype="void")
            add436: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul585, add434)
            rms_norm295: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add436, model_layers_1_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv734 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight5, model_layers_1_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims586: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv734, axes=None)
            matmul586: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm295, permute_dims586, out_dtype="void")
            split145: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul586, indices_or_sections=2, axis=-1)
            split_0145: R.Tensor((1, seq_len, 11008), dtype="float16") = split145[0]
            split_1145: R.Tensor((1, seq_len, 11008), dtype="float16") = split145[1]
            silu145: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0145)
            mul145: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu145, split_1145)
            lv735 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight5, model_layers_1_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims587: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv735, axes=None)
            matmul587: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul145, permute_dims587, out_dtype="void")
            add437: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul587, add436)
            rms_norm296: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add437, model_layers_2_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv736 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight5, model_layers_2_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims588: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv736, axes=None)
            matmul588: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm296, permute_dims588, out_dtype="void")
            add438: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul588, model_layers_2_self_attn_c_attn_bias5)
            reshape584: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add438, R.shape([1, seq_len, 20, 128]))
            reshape585: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape584, R.shape([seq_len, 20, 128]))
            lv737 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape585), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape586: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv737, R.shape([1, seq_len, 16, 128]))
            reshape587: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape586, R.shape([1, seq_len, 2048]))
            lv738 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight5, model_layers_2_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims589: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv738, axes=None)
            matmul589: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape587, permute_dims589, out_dtype="void")
            add439: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul589, add437)
            rms_norm297: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add439, model_layers_2_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv739 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight5, model_layers_2_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims590: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv739, axes=None)
            matmul590: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm297, permute_dims590, out_dtype="void")
            split146: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul590, indices_or_sections=2, axis=-1)
            split_0146: R.Tensor((1, seq_len, 11008), dtype="float16") = split146[0]
            split_1146: R.Tensor((1, seq_len, 11008), dtype="float16") = split146[1]
            silu146: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0146)
            mul146: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu146, split_1146)
            lv740 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight5, model_layers_2_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims591: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv740, axes=None)
            matmul591: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul146, permute_dims591, out_dtype="void")
            add440: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul591, add439)
            rms_norm298: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add440, model_layers_3_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv741 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight5, model_layers_3_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims592: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv741, axes=None)
            matmul592: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm298, permute_dims592, out_dtype="void")
            add441: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul592, model_layers_3_self_attn_c_attn_bias5)
            reshape588: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add441, R.shape([1, seq_len, 20, 128]))
            reshape589: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape588, R.shape([seq_len, 20, 128]))
            lv742 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape589), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape590: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv742, R.shape([1, seq_len, 16, 128]))
            reshape591: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape590, R.shape([1, seq_len, 2048]))
            lv743 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight5, model_layers_3_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims593: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv743, axes=None)
            matmul593: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape591, permute_dims593, out_dtype="void")
            add442: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul593, add440)
            rms_norm299: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add442, model_layers_3_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv744 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight5, model_layers_3_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims594: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv744, axes=None)
            matmul594: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm299, permute_dims594, out_dtype="void")
            split147: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul594, indices_or_sections=2, axis=-1)
            split_0147: R.Tensor((1, seq_len, 11008), dtype="float16") = split147[0]
            split_1147: R.Tensor((1, seq_len, 11008), dtype="float16") = split147[1]
            silu147: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0147)
            mul147: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu147, split_1147)
            lv745 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight5, model_layers_3_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims595: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv745, axes=None)
            matmul595: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul147, permute_dims595, out_dtype="void")
            add443: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul595, add442)
            rms_norm300: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add443, model_layers_4_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv746 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight5, model_layers_4_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims596: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv746, axes=None)
            matmul596: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm300, permute_dims596, out_dtype="void")
            add444: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul596, model_layers_4_self_attn_c_attn_bias5)
            reshape592: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add444, R.shape([1, seq_len, 20, 128]))
            reshape593: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape592, R.shape([seq_len, 20, 128]))
            lv747 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape593), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape594: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv747, R.shape([1, seq_len, 16, 128]))
            reshape595: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape594, R.shape([1, seq_len, 2048]))
            lv748 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight5, model_layers_4_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims597: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv748, axes=None)
            matmul597: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape595, permute_dims597, out_dtype="void")
            add445: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul597, add443)
            rms_norm301: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add445, model_layers_4_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv749 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight5, model_layers_4_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims598: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv749, axes=None)
            matmul598: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm301, permute_dims598, out_dtype="void")
            split148: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul598, indices_or_sections=2, axis=-1)
            split_0148: R.Tensor((1, seq_len, 11008), dtype="float16") = split148[0]
            split_1148: R.Tensor((1, seq_len, 11008), dtype="float16") = split148[1]
            silu148: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0148)
            mul148: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu148, split_1148)
            lv750 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight5, model_layers_4_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims599: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv750, axes=None)
            matmul599: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul148, permute_dims599, out_dtype="void")
            add446: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul599, add445)
            rms_norm302: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add446, model_layers_5_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv751 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight5, model_layers_5_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims600: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv751, axes=None)
            matmul600: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm302, permute_dims600, out_dtype="void")
            add447: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul600, model_layers_5_self_attn_c_attn_bias5)
            reshape596: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add447, R.shape([1, seq_len, 20, 128]))
            reshape597: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape596, R.shape([seq_len, 20, 128]))
            lv752 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape597), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape598: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv752, R.shape([1, seq_len, 16, 128]))
            reshape599: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape598, R.shape([1, seq_len, 2048]))
            lv753 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight5, model_layers_5_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims601: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv753, axes=None)
            matmul601: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape599, permute_dims601, out_dtype="void")
            add448: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul601, add446)
            rms_norm303: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add448, model_layers_5_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv754 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight5, model_layers_5_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims602: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv754, axes=None)
            matmul602: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm303, permute_dims602, out_dtype="void")
            split149: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul602, indices_or_sections=2, axis=-1)
            split_0149: R.Tensor((1, seq_len, 11008), dtype="float16") = split149[0]
            split_1149: R.Tensor((1, seq_len, 11008), dtype="float16") = split149[1]
            silu149: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0149)
            mul149: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu149, split_1149)
            lv755 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight5, model_layers_5_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims603: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv755, axes=None)
            matmul603: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul149, permute_dims603, out_dtype="void")
            add449: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul603, add448)
            rms_norm304: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add449, model_layers_6_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv756 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight5, model_layers_6_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims604: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv756, axes=None)
            matmul604: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm304, permute_dims604, out_dtype="void")
            add450: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul604, model_layers_6_self_attn_c_attn_bias5)
            reshape600: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add450, R.shape([1, seq_len, 20, 128]))
            reshape601: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape600, R.shape([seq_len, 20, 128]))
            lv757 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape601), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape602: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv757, R.shape([1, seq_len, 16, 128]))
            reshape603: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape602, R.shape([1, seq_len, 2048]))
            lv758 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight5, model_layers_6_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims605: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv758, axes=None)
            matmul605: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape603, permute_dims605, out_dtype="void")
            add451: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul605, add449)
            rms_norm305: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add451, model_layers_6_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv759 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight5, model_layers_6_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims606: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv759, axes=None)
            matmul606: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm305, permute_dims606, out_dtype="void")
            split150: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul606, indices_or_sections=2, axis=-1)
            split_0150: R.Tensor((1, seq_len, 11008), dtype="float16") = split150[0]
            split_1150: R.Tensor((1, seq_len, 11008), dtype="float16") = split150[1]
            silu150: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0150)
            mul150: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu150, split_1150)
            lv760 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight5, model_layers_6_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims607: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv760, axes=None)
            matmul607: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul150, permute_dims607, out_dtype="void")
            add452: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul607, add451)
            rms_norm306: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add452, model_layers_7_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv761 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight5, model_layers_7_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims608: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv761, axes=None)
            matmul608: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm306, permute_dims608, out_dtype="void")
            add453: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul608, model_layers_7_self_attn_c_attn_bias5)
            reshape604: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add453, R.shape([1, seq_len, 20, 128]))
            reshape605: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape604, R.shape([seq_len, 20, 128]))
            lv762 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape605), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape606: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv762, R.shape([1, seq_len, 16, 128]))
            reshape607: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape606, R.shape([1, seq_len, 2048]))
            lv763 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight5, model_layers_7_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims609: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv763, axes=None)
            matmul609: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape607, permute_dims609, out_dtype="void")
            add454: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul609, add452)
            rms_norm307: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add454, model_layers_7_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv764 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight5, model_layers_7_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims610: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv764, axes=None)
            matmul610: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm307, permute_dims610, out_dtype="void")
            split151: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul610, indices_or_sections=2, axis=-1)
            split_0151: R.Tensor((1, seq_len, 11008), dtype="float16") = split151[0]
            split_1151: R.Tensor((1, seq_len, 11008), dtype="float16") = split151[1]
            silu151: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0151)
            mul151: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu151, split_1151)
            lv765 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight5, model_layers_7_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims611: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv765, axes=None)
            matmul611: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul151, permute_dims611, out_dtype="void")
            add455: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul611, add454)
            rms_norm308: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add455, model_layers_8_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv766 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight5, model_layers_8_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims612: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv766, axes=None)
            matmul612: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm308, permute_dims612, out_dtype="void")
            add456: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul612, model_layers_8_self_attn_c_attn_bias5)
            reshape608: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add456, R.shape([1, seq_len, 20, 128]))
            reshape609: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape608, R.shape([seq_len, 20, 128]))
            lv767 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape609), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape610: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv767, R.shape([1, seq_len, 16, 128]))
            reshape611: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape610, R.shape([1, seq_len, 2048]))
            lv768 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight5, model_layers_8_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims613: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv768, axes=None)
            matmul613: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape611, permute_dims613, out_dtype="void")
            add457: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul613, add455)
            rms_norm309: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add457, model_layers_8_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv769 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight5, model_layers_8_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims614: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv769, axes=None)
            matmul614: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm309, permute_dims614, out_dtype="void")
            split152: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul614, indices_or_sections=2, axis=-1)
            split_0152: R.Tensor((1, seq_len, 11008), dtype="float16") = split152[0]
            split_1152: R.Tensor((1, seq_len, 11008), dtype="float16") = split152[1]
            silu152: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0152)
            mul152: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu152, split_1152)
            lv770 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight5, model_layers_8_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims615: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv770, axes=None)
            matmul615: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul152, permute_dims615, out_dtype="void")
            add458: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul615, add457)
            rms_norm310: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add458, model_layers_9_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv771 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight5, model_layers_9_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims616: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv771, axes=None)
            matmul616: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm310, permute_dims616, out_dtype="void")
            add459: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul616, model_layers_9_self_attn_c_attn_bias5)
            reshape612: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add459, R.shape([1, seq_len, 20, 128]))
            reshape613: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape612, R.shape([seq_len, 20, 128]))
            lv772 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape613), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape614: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv772, R.shape([1, seq_len, 16, 128]))
            reshape615: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape614, R.shape([1, seq_len, 2048]))
            lv773 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight5, model_layers_9_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims617: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv773, axes=None)
            matmul617: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape615, permute_dims617, out_dtype="void")
            add460: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul617, add458)
            rms_norm311: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add460, model_layers_9_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv774 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight5, model_layers_9_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims618: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv774, axes=None)
            matmul618: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm311, permute_dims618, out_dtype="void")
            split153: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul618, indices_or_sections=2, axis=-1)
            split_0153: R.Tensor((1, seq_len, 11008), dtype="float16") = split153[0]
            split_1153: R.Tensor((1, seq_len, 11008), dtype="float16") = split153[1]
            silu153: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0153)
            mul153: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu153, split_1153)
            lv775 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight5, model_layers_9_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims619: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv775, axes=None)
            matmul619: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul153, permute_dims619, out_dtype="void")
            add461: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul619, add460)
            rms_norm312: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add461, model_layers_10_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv776 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight5, model_layers_10_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims620: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv776, axes=None)
            matmul620: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm312, permute_dims620, out_dtype="void")
            add462: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul620, model_layers_10_self_attn_c_attn_bias5)
            reshape616: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add462, R.shape([1, seq_len, 20, 128]))
            reshape617: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape616, R.shape([seq_len, 20, 128]))
            lv777 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape617), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape618: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv777, R.shape([1, seq_len, 16, 128]))
            reshape619: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape618, R.shape([1, seq_len, 2048]))
            lv778 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight5, model_layers_10_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims621: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv778, axes=None)
            matmul621: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape619, permute_dims621, out_dtype="void")
            add463: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul621, add461)
            rms_norm313: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add463, model_layers_10_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv779 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight5, model_layers_10_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims622: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv779, axes=None)
            matmul622: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm313, permute_dims622, out_dtype="void")
            split154: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul622, indices_or_sections=2, axis=-1)
            split_0154: R.Tensor((1, seq_len, 11008), dtype="float16") = split154[0]
            split_1154: R.Tensor((1, seq_len, 11008), dtype="float16") = split154[1]
            silu154: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0154)
            mul154: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu154, split_1154)
            lv780 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight5, model_layers_10_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims623: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv780, axes=None)
            matmul623: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul154, permute_dims623, out_dtype="void")
            add464: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul623, add463)
            rms_norm314: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add464, model_layers_11_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv781 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight5, model_layers_11_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims624: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv781, axes=None)
            matmul624: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm314, permute_dims624, out_dtype="void")
            add465: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul624, model_layers_11_self_attn_c_attn_bias5)
            reshape620: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add465, R.shape([1, seq_len, 20, 128]))
            reshape621: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape620, R.shape([seq_len, 20, 128]))
            lv782 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape621), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape622: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv782, R.shape([1, seq_len, 16, 128]))
            reshape623: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape622, R.shape([1, seq_len, 2048]))
            lv783 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight5, model_layers_11_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims625: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv783, axes=None)
            matmul625: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape623, permute_dims625, out_dtype="void")
            add466: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul625, add464)
            rms_norm315: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add466, model_layers_11_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv784 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight5, model_layers_11_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims626: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv784, axes=None)
            matmul626: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm315, permute_dims626, out_dtype="void")
            split155: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul626, indices_or_sections=2, axis=-1)
            split_0155: R.Tensor((1, seq_len, 11008), dtype="float16") = split155[0]
            split_1155: R.Tensor((1, seq_len, 11008), dtype="float16") = split155[1]
            silu155: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0155)
            mul155: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu155, split_1155)
            lv785 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight5, model_layers_11_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims627: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv785, axes=None)
            matmul627: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul155, permute_dims627, out_dtype="void")
            add467: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul627, add466)
            rms_norm316: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add467, model_layers_12_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv786 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight5, model_layers_12_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims628: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv786, axes=None)
            matmul628: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm316, permute_dims628, out_dtype="void")
            add468: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul628, model_layers_12_self_attn_c_attn_bias5)
            reshape624: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add468, R.shape([1, seq_len, 20, 128]))
            reshape625: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape624, R.shape([seq_len, 20, 128]))
            lv787 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape625), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape626: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv787, R.shape([1, seq_len, 16, 128]))
            reshape627: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape626, R.shape([1, seq_len, 2048]))
            lv788 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight5, model_layers_12_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims629: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv788, axes=None)
            matmul629: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape627, permute_dims629, out_dtype="void")
            add469: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul629, add467)
            rms_norm317: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add469, model_layers_12_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv789 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight5, model_layers_12_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims630: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv789, axes=None)
            matmul630: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm317, permute_dims630, out_dtype="void")
            split156: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul630, indices_or_sections=2, axis=-1)
            split_0156: R.Tensor((1, seq_len, 11008), dtype="float16") = split156[0]
            split_1156: R.Tensor((1, seq_len, 11008), dtype="float16") = split156[1]
            silu156: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0156)
            mul156: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu156, split_1156)
            lv790 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight5, model_layers_12_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims631: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv790, axes=None)
            matmul631: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul156, permute_dims631, out_dtype="void")
            add470: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul631, add469)
            rms_norm318: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add470, model_layers_13_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv791 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight5, model_layers_13_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims632: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv791, axes=None)
            matmul632: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm318, permute_dims632, out_dtype="void")
            add471: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul632, model_layers_13_self_attn_c_attn_bias5)
            reshape628: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add471, R.shape([1, seq_len, 20, 128]))
            reshape629: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape628, R.shape([seq_len, 20, 128]))
            lv792 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape629), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape630: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv792, R.shape([1, seq_len, 16, 128]))
            reshape631: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape630, R.shape([1, seq_len, 2048]))
            lv793 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight5, model_layers_13_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims633: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv793, axes=None)
            matmul633: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape631, permute_dims633, out_dtype="void")
            add472: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul633, add470)
            rms_norm319: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add472, model_layers_13_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv794 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight5, model_layers_13_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims634: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv794, axes=None)
            matmul634: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm319, permute_dims634, out_dtype="void")
            split157: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul634, indices_or_sections=2, axis=-1)
            split_0157: R.Tensor((1, seq_len, 11008), dtype="float16") = split157[0]
            split_1157: R.Tensor((1, seq_len, 11008), dtype="float16") = split157[1]
            silu157: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0157)
            mul157: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu157, split_1157)
            lv795 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight5, model_layers_13_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims635: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv795, axes=None)
            matmul635: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul157, permute_dims635, out_dtype="void")
            add473: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul635, add472)
            rms_norm320: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add473, model_layers_14_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv796 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight5, model_layers_14_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims636: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv796, axes=None)
            matmul636: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm320, permute_dims636, out_dtype="void")
            add474: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul636, model_layers_14_self_attn_c_attn_bias5)
            reshape632: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add474, R.shape([1, seq_len, 20, 128]))
            reshape633: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape632, R.shape([seq_len, 20, 128]))
            lv797 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape633), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape634: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv797, R.shape([1, seq_len, 16, 128]))
            reshape635: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape634, R.shape([1, seq_len, 2048]))
            lv798 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight5, model_layers_14_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims637: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv798, axes=None)
            matmul637: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape635, permute_dims637, out_dtype="void")
            add475: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul637, add473)
            rms_norm321: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add475, model_layers_14_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv799 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight5, model_layers_14_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims638: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv799, axes=None)
            matmul638: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm321, permute_dims638, out_dtype="void")
            split158: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul638, indices_or_sections=2, axis=-1)
            split_0158: R.Tensor((1, seq_len, 11008), dtype="float16") = split158[0]
            split_1158: R.Tensor((1, seq_len, 11008), dtype="float16") = split158[1]
            silu158: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0158)
            mul158: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu158, split_1158)
            lv800 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight5, model_layers_14_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims639: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv800, axes=None)
            matmul639: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul158, permute_dims639, out_dtype="void")
            add476: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul639, add475)
            rms_norm322: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add476, model_layers_15_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv801 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight5, model_layers_15_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims640: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv801, axes=None)
            matmul640: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm322, permute_dims640, out_dtype="void")
            add477: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul640, model_layers_15_self_attn_c_attn_bias5)
            reshape636: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add477, R.shape([1, seq_len, 20, 128]))
            reshape637: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape636, R.shape([seq_len, 20, 128]))
            lv802 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape637), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape638: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv802, R.shape([1, seq_len, 16, 128]))
            reshape639: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape638, R.shape([1, seq_len, 2048]))
            lv803 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight5, model_layers_15_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims641: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv803, axes=None)
            matmul641: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape639, permute_dims641, out_dtype="void")
            add478: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul641, add476)
            rms_norm323: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add478, model_layers_15_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv804 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight5, model_layers_15_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims642: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv804, axes=None)
            matmul642: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm323, permute_dims642, out_dtype="void")
            split159: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul642, indices_or_sections=2, axis=-1)
            split_0159: R.Tensor((1, seq_len, 11008), dtype="float16") = split159[0]
            split_1159: R.Tensor((1, seq_len, 11008), dtype="float16") = split159[1]
            silu159: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0159)
            mul159: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu159, split_1159)
            lv805 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight5, model_layers_15_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims643: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv805, axes=None)
            matmul643: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul159, permute_dims643, out_dtype="void")
            add479: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul643, add478)
            rms_norm324: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add479, model_layers_16_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv806 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight5, model_layers_16_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims644: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv806, axes=None)
            matmul644: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm324, permute_dims644, out_dtype="void")
            add480: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul644, model_layers_16_self_attn_c_attn_bias5)
            reshape640: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add480, R.shape([1, seq_len, 20, 128]))
            reshape641: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape640, R.shape([seq_len, 20, 128]))
            lv807 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape641), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape642: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv807, R.shape([1, seq_len, 16, 128]))
            reshape643: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape642, R.shape([1, seq_len, 2048]))
            lv808 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight5, model_layers_16_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims645: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv808, axes=None)
            matmul645: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape643, permute_dims645, out_dtype="void")
            add481: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul645, add479)
            rms_norm325: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add481, model_layers_16_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv809 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight5, model_layers_16_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims646: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv809, axes=None)
            matmul646: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm325, permute_dims646, out_dtype="void")
            split160: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul646, indices_or_sections=2, axis=-1)
            split_0160: R.Tensor((1, seq_len, 11008), dtype="float16") = split160[0]
            split_1160: R.Tensor((1, seq_len, 11008), dtype="float16") = split160[1]
            silu160: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0160)
            mul160: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu160, split_1160)
            lv810 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight5, model_layers_16_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims647: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv810, axes=None)
            matmul647: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul160, permute_dims647, out_dtype="void")
            add482: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul647, add481)
            rms_norm326: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add482, model_layers_17_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv811 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight5, model_layers_17_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims648: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv811, axes=None)
            matmul648: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm326, permute_dims648, out_dtype="void")
            add483: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul648, model_layers_17_self_attn_c_attn_bias5)
            reshape644: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add483, R.shape([1, seq_len, 20, 128]))
            reshape645: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape644, R.shape([seq_len, 20, 128]))
            lv812 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape645), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape646: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv812, R.shape([1, seq_len, 16, 128]))
            reshape647: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape646, R.shape([1, seq_len, 2048]))
            lv813 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight5, model_layers_17_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims649: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv813, axes=None)
            matmul649: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape647, permute_dims649, out_dtype="void")
            add484: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul649, add482)
            rms_norm327: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add484, model_layers_17_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv814 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight5, model_layers_17_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims650: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv814, axes=None)
            matmul650: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm327, permute_dims650, out_dtype="void")
            split161: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul650, indices_or_sections=2, axis=-1)
            split_0161: R.Tensor((1, seq_len, 11008), dtype="float16") = split161[0]
            split_1161: R.Tensor((1, seq_len, 11008), dtype="float16") = split161[1]
            silu161: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0161)
            mul161: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu161, split_1161)
            lv815 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight5, model_layers_17_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims651: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv815, axes=None)
            matmul651: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul161, permute_dims651, out_dtype="void")
            add485: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul651, add484)
            rms_norm328: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add485, model_layers_18_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv816 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight5, model_layers_18_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims652: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv816, axes=None)
            matmul652: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm328, permute_dims652, out_dtype="void")
            add486: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul652, model_layers_18_self_attn_c_attn_bias5)
            reshape648: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add486, R.shape([1, seq_len, 20, 128]))
            reshape649: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape648, R.shape([seq_len, 20, 128]))
            lv817 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape649), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape650: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv817, R.shape([1, seq_len, 16, 128]))
            reshape651: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape650, R.shape([1, seq_len, 2048]))
            lv818 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight5, model_layers_18_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims653: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv818, axes=None)
            matmul653: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape651, permute_dims653, out_dtype="void")
            add487: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul653, add485)
            rms_norm329: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add487, model_layers_18_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv819 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight5, model_layers_18_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims654: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv819, axes=None)
            matmul654: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm329, permute_dims654, out_dtype="void")
            split162: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul654, indices_or_sections=2, axis=-1)
            split_0162: R.Tensor((1, seq_len, 11008), dtype="float16") = split162[0]
            split_1162: R.Tensor((1, seq_len, 11008), dtype="float16") = split162[1]
            silu162: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0162)
            mul162: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu162, split_1162)
            lv820 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight5, model_layers_18_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims655: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv820, axes=None)
            matmul655: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul162, permute_dims655, out_dtype="void")
            add488: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul655, add487)
            rms_norm330: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add488, model_layers_19_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv821 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight5, model_layers_19_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims656: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv821, axes=None)
            matmul656: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm330, permute_dims656, out_dtype="void")
            add489: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul656, model_layers_19_self_attn_c_attn_bias5)
            reshape652: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add489, R.shape([1, seq_len, 20, 128]))
            reshape653: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape652, R.shape([seq_len, 20, 128]))
            lv822 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape653), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape654: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv822, R.shape([1, seq_len, 16, 128]))
            reshape655: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape654, R.shape([1, seq_len, 2048]))
            lv823 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight5, model_layers_19_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims657: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv823, axes=None)
            matmul657: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape655, permute_dims657, out_dtype="void")
            add490: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul657, add488)
            rms_norm331: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add490, model_layers_19_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv824 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight5, model_layers_19_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims658: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv824, axes=None)
            matmul658: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm331, permute_dims658, out_dtype="void")
            split163: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul658, indices_or_sections=2, axis=-1)
            split_0163: R.Tensor((1, seq_len, 11008), dtype="float16") = split163[0]
            split_1163: R.Tensor((1, seq_len, 11008), dtype="float16") = split163[1]
            silu163: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0163)
            mul163: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu163, split_1163)
            lv825 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight5, model_layers_19_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims659: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv825, axes=None)
            matmul659: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul163, permute_dims659, out_dtype="void")
            add491: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul659, add490)
            rms_norm332: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add491, model_layers_20_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv826 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight5, model_layers_20_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims660: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv826, axes=None)
            matmul660: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm332, permute_dims660, out_dtype="void")
            add492: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul660, model_layers_20_self_attn_c_attn_bias5)
            reshape656: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add492, R.shape([1, seq_len, 20, 128]))
            reshape657: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape656, R.shape([seq_len, 20, 128]))
            lv827 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape657), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape658: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv827, R.shape([1, seq_len, 16, 128]))
            reshape659: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape658, R.shape([1, seq_len, 2048]))
            lv828 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight5, model_layers_20_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims661: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv828, axes=None)
            matmul661: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape659, permute_dims661, out_dtype="void")
            add493: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul661, add491)
            rms_norm333: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add493, model_layers_20_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv829 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight5, model_layers_20_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims662: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv829, axes=None)
            matmul662: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm333, permute_dims662, out_dtype="void")
            split164: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul662, indices_or_sections=2, axis=-1)
            split_0164: R.Tensor((1, seq_len, 11008), dtype="float16") = split164[0]
            split_1164: R.Tensor((1, seq_len, 11008), dtype="float16") = split164[1]
            silu164: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0164)
            mul164: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu164, split_1164)
            lv830 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight5, model_layers_20_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims663: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv830, axes=None)
            matmul663: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul164, permute_dims663, out_dtype="void")
            add494: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul663, add493)
            rms_norm334: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add494, model_layers_21_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv831 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight5, model_layers_21_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims664: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv831, axes=None)
            matmul664: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm334, permute_dims664, out_dtype="void")
            add495: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul664, model_layers_21_self_attn_c_attn_bias5)
            reshape660: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add495, R.shape([1, seq_len, 20, 128]))
            reshape661: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape660, R.shape([seq_len, 20, 128]))
            lv832 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape661), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape662: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv832, R.shape([1, seq_len, 16, 128]))
            reshape663: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape662, R.shape([1, seq_len, 2048]))
            lv833 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight5, model_layers_21_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims665: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv833, axes=None)
            matmul665: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape663, permute_dims665, out_dtype="void")
            add496: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul665, add494)
            rms_norm335: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add496, model_layers_21_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv834 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight5, model_layers_21_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims666: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv834, axes=None)
            matmul666: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm335, permute_dims666, out_dtype="void")
            split165: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul666, indices_or_sections=2, axis=-1)
            split_0165: R.Tensor((1, seq_len, 11008), dtype="float16") = split165[0]
            split_1165: R.Tensor((1, seq_len, 11008), dtype="float16") = split165[1]
            silu165: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0165)
            mul165: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu165, split_1165)
            lv835 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight5, model_layers_21_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims667: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv835, axes=None)
            matmul667: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul165, permute_dims667, out_dtype="void")
            add497: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul667, add496)
            rms_norm336: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add497, model_layers_22_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv836 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight5, model_layers_22_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims668: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv836, axes=None)
            matmul668: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm336, permute_dims668, out_dtype="void")
            add498: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul668, model_layers_22_self_attn_c_attn_bias5)
            reshape664: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add498, R.shape([1, seq_len, 20, 128]))
            reshape665: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape664, R.shape([seq_len, 20, 128]))
            lv837 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape665), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape666: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv837, R.shape([1, seq_len, 16, 128]))
            reshape667: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape666, R.shape([1, seq_len, 2048]))
            lv838 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight5, model_layers_22_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims669: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv838, axes=None)
            matmul669: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape667, permute_dims669, out_dtype="void")
            add499: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul669, add497)
            rms_norm337: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add499, model_layers_22_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv839 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight5, model_layers_22_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims670: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv839, axes=None)
            matmul670: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm337, permute_dims670, out_dtype="void")
            split166: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul670, indices_or_sections=2, axis=-1)
            split_0166: R.Tensor((1, seq_len, 11008), dtype="float16") = split166[0]
            split_1166: R.Tensor((1, seq_len, 11008), dtype="float16") = split166[1]
            silu166: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0166)
            mul166: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu166, split_1166)
            lv840 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight5, model_layers_22_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims671: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv840, axes=None)
            matmul671: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul166, permute_dims671, out_dtype="void")
            add500: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul671, add499)
            rms_norm338: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add500, model_layers_23_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv841 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight5, model_layers_23_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims672: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv841, axes=None)
            matmul672: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm338, permute_dims672, out_dtype="void")
            add501: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul672, model_layers_23_self_attn_c_attn_bias5)
            reshape668: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add501, R.shape([1, seq_len, 20, 128]))
            reshape669: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape668, R.shape([seq_len, 20, 128]))
            lv842 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape669), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape670: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv842, R.shape([1, seq_len, 16, 128]))
            reshape671: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape670, R.shape([1, seq_len, 2048]))
            lv843 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight5, model_layers_23_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims673: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv843, axes=None)
            matmul673: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape671, permute_dims673, out_dtype="void")
            add502: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul673, add500)
            rms_norm339: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add502, model_layers_23_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv844 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight5, model_layers_23_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims674: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv844, axes=None)
            matmul674: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm339, permute_dims674, out_dtype="void")
            split167: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul674, indices_or_sections=2, axis=-1)
            split_0167: R.Tensor((1, seq_len, 11008), dtype="float16") = split167[0]
            split_1167: R.Tensor((1, seq_len, 11008), dtype="float16") = split167[1]
            silu167: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0167)
            mul167: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu167, split_1167)
            lv845 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight5, model_layers_23_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims675: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv845, axes=None)
            matmul675: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul167, permute_dims675, out_dtype="void")
            add503: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul675, add502)
            rms_norm340: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add503, model_layers_24_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv846 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight5, model_layers_24_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims676: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv846, axes=None)
            matmul676: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm340, permute_dims676, out_dtype="void")
            add504: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul676, model_layers_24_self_attn_c_attn_bias5)
            reshape672: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add504, R.shape([1, seq_len, 20, 128]))
            reshape673: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape672, R.shape([seq_len, 20, 128]))
            lv847 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape673), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape674: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv847, R.shape([1, seq_len, 16, 128]))
            reshape675: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape674, R.shape([1, seq_len, 2048]))
            lv848 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight5, model_layers_24_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims677: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv848, axes=None)
            matmul677: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape675, permute_dims677, out_dtype="void")
            add505: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul677, add503)
            rms_norm341: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add505, model_layers_24_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv849 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight5, model_layers_24_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims678: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv849, axes=None)
            matmul678: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm341, permute_dims678, out_dtype="void")
            split168: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul678, indices_or_sections=2, axis=-1)
            split_0168: R.Tensor((1, seq_len, 11008), dtype="float16") = split168[0]
            split_1168: R.Tensor((1, seq_len, 11008), dtype="float16") = split168[1]
            silu168: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0168)
            mul168: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu168, split_1168)
            lv850 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight5, model_layers_24_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims679: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv850, axes=None)
            matmul679: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul168, permute_dims679, out_dtype="void")
            add506: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul679, add505)
            rms_norm342: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add506, model_layers_25_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv851 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight5, model_layers_25_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims680: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv851, axes=None)
            matmul680: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm342, permute_dims680, out_dtype="void")
            add507: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul680, model_layers_25_self_attn_c_attn_bias5)
            reshape676: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add507, R.shape([1, seq_len, 20, 128]))
            reshape677: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape676, R.shape([seq_len, 20, 128]))
            lv852 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape677), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape678: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv852, R.shape([1, seq_len, 16, 128]))
            reshape679: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape678, R.shape([1, seq_len, 2048]))
            lv853 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight5, model_layers_25_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims681: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv853, axes=None)
            matmul681: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape679, permute_dims681, out_dtype="void")
            add508: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul681, add506)
            rms_norm343: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add508, model_layers_25_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv854 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight5, model_layers_25_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims682: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv854, axes=None)
            matmul682: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm343, permute_dims682, out_dtype="void")
            split169: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul682, indices_or_sections=2, axis=-1)
            split_0169: R.Tensor((1, seq_len, 11008), dtype="float16") = split169[0]
            split_1169: R.Tensor((1, seq_len, 11008), dtype="float16") = split169[1]
            silu169: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0169)
            mul169: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu169, split_1169)
            lv855 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight5, model_layers_25_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims683: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv855, axes=None)
            matmul683: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul169, permute_dims683, out_dtype="void")
            add509: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul683, add508)
            rms_norm344: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add509, model_layers_26_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv856 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight5, model_layers_26_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims684: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv856, axes=None)
            matmul684: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm344, permute_dims684, out_dtype="void")
            add510: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul684, model_layers_26_self_attn_c_attn_bias5)
            reshape680: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add510, R.shape([1, seq_len, 20, 128]))
            reshape681: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape680, R.shape([seq_len, 20, 128]))
            lv857 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape681), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape682: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv857, R.shape([1, seq_len, 16, 128]))
            reshape683: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape682, R.shape([1, seq_len, 2048]))
            lv858 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight5, model_layers_26_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims685: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv858, axes=None)
            matmul685: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape683, permute_dims685, out_dtype="void")
            add511: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul685, add509)
            rms_norm345: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add511, model_layers_26_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv859 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight5, model_layers_26_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims686: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv859, axes=None)
            matmul686: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm345, permute_dims686, out_dtype="void")
            split170: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul686, indices_or_sections=2, axis=-1)
            split_0170: R.Tensor((1, seq_len, 11008), dtype="float16") = split170[0]
            split_1170: R.Tensor((1, seq_len, 11008), dtype="float16") = split170[1]
            silu170: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0170)
            mul170: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu170, split_1170)
            lv860 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight5, model_layers_26_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims687: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv860, axes=None)
            matmul687: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul170, permute_dims687, out_dtype="void")
            add512: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul687, add511)
            rms_norm346: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add512, model_layers_27_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv861 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight5, model_layers_27_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims688: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv861, axes=None)
            matmul688: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm346, permute_dims688, out_dtype="void")
            add513: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul688, model_layers_27_self_attn_c_attn_bias5)
            reshape684: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add513, R.shape([1, seq_len, 20, 128]))
            reshape685: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape684, R.shape([seq_len, 20, 128]))
            lv862 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape685), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape686: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv862, R.shape([1, seq_len, 16, 128]))
            reshape687: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape686, R.shape([1, seq_len, 2048]))
            lv863 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight5, model_layers_27_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims689: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv863, axes=None)
            matmul689: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape687, permute_dims689, out_dtype="void")
            add514: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul689, add512)
            rms_norm347: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add514, model_layers_27_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv864 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight5, model_layers_27_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims690: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv864, axes=None)
            matmul690: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm347, permute_dims690, out_dtype="void")
            split171: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul690, indices_or_sections=2, axis=-1)
            split_0171: R.Tensor((1, seq_len, 11008), dtype="float16") = split171[0]
            split_1171: R.Tensor((1, seq_len, 11008), dtype="float16") = split171[1]
            silu171: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0171)
            mul171: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu171, split_1171)
            lv865 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight5, model_layers_27_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims691: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv865, axes=None)
            matmul691: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul171, permute_dims691, out_dtype="void")
            add515: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul691, add514)
            rms_norm348: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add515, model_layers_28_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv866 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight5, model_layers_28_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims692: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv866, axes=None)
            matmul692: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm348, permute_dims692, out_dtype="void")
            add516: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul692, model_layers_28_self_attn_c_attn_bias5)
            reshape688: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add516, R.shape([1, seq_len, 20, 128]))
            reshape689: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape688, R.shape([seq_len, 20, 128]))
            lv867 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape689), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape690: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv867, R.shape([1, seq_len, 16, 128]))
            reshape691: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape690, R.shape([1, seq_len, 2048]))
            lv868 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight5, model_layers_28_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims693: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv868, axes=None)
            matmul693: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape691, permute_dims693, out_dtype="void")
            add517: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul693, add515)
            rms_norm349: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add517, model_layers_28_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv869 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight5, model_layers_28_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims694: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv869, axes=None)
            matmul694: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm349, permute_dims694, out_dtype="void")
            split172: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul694, indices_or_sections=2, axis=-1)
            split_0172: R.Tensor((1, seq_len, 11008), dtype="float16") = split172[0]
            split_1172: R.Tensor((1, seq_len, 11008), dtype="float16") = split172[1]
            silu172: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0172)
            mul172: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu172, split_1172)
            lv870 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight5, model_layers_28_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims695: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv870, axes=None)
            matmul695: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul172, permute_dims695, out_dtype="void")
            add518: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul695, add517)
            rms_norm350: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add518, model_layers_29_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv871 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight5, model_layers_29_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims696: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv871, axes=None)
            matmul696: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm350, permute_dims696, out_dtype="void")
            add519: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul696, model_layers_29_self_attn_c_attn_bias5)
            reshape692: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add519, R.shape([1, seq_len, 20, 128]))
            reshape693: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape692, R.shape([seq_len, 20, 128]))
            lv872 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape693), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape694: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv872, R.shape([1, seq_len, 16, 128]))
            reshape695: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape694, R.shape([1, seq_len, 2048]))
            lv873 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight5, model_layers_29_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims697: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv873, axes=None)
            matmul697: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape695, permute_dims697, out_dtype="void")
            add520: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul697, add518)
            rms_norm351: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add520, model_layers_29_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv874 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight5, model_layers_29_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims698: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv874, axes=None)
            matmul698: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm351, permute_dims698, out_dtype="void")
            split173: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul698, indices_or_sections=2, axis=-1)
            split_0173: R.Tensor((1, seq_len, 11008), dtype="float16") = split173[0]
            split_1173: R.Tensor((1, seq_len, 11008), dtype="float16") = split173[1]
            silu173: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0173)
            mul173: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu173, split_1173)
            lv875 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight5, model_layers_29_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims699: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv875, axes=None)
            matmul699: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul173, permute_dims699, out_dtype="void")
            add521: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul699, add520)
            rms_norm352: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add521, model_layers_30_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv876 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight5, model_layers_30_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims700: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv876, axes=None)
            matmul700: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm352, permute_dims700, out_dtype="void")
            add522: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul700, model_layers_30_self_attn_c_attn_bias5)
            reshape696: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add522, R.shape([1, seq_len, 20, 128]))
            reshape697: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape696, R.shape([seq_len, 20, 128]))
            lv877 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape697), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape698: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv877, R.shape([1, seq_len, 16, 128]))
            reshape699: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape698, R.shape([1, seq_len, 2048]))
            lv878 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight5, model_layers_30_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims701: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv878, axes=None)
            matmul701: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape699, permute_dims701, out_dtype="void")
            add523: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul701, add521)
            rms_norm353: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add523, model_layers_30_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv879 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight5, model_layers_30_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims702: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv879, axes=None)
            matmul702: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm353, permute_dims702, out_dtype="void")
            split174: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul702, indices_or_sections=2, axis=-1)
            split_0174: R.Tensor((1, seq_len, 11008), dtype="float16") = split174[0]
            split_1174: R.Tensor((1, seq_len, 11008), dtype="float16") = split174[1]
            silu174: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0174)
            mul174: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu174, split_1174)
            lv880 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight5, model_layers_30_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims703: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv880, axes=None)
            matmul703: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul174, permute_dims703, out_dtype="void")
            add524: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul703, add523)
            rms_norm354: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add524, model_layers_31_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv881 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight5, model_layers_31_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims704: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv881, axes=None)
            matmul704: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm354, permute_dims704, out_dtype="void")
            add525: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul704, model_layers_31_self_attn_c_attn_bias5)
            reshape700: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add525, R.shape([1, seq_len, 20, 128]))
            reshape701: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape700, R.shape([seq_len, 20, 128]))
            lv882 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape701), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape702: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv882, R.shape([1, seq_len, 16, 128]))
            reshape703: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape702, R.shape([1, seq_len, 2048]))
            lv883 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight5, model_layers_31_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims705: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv883, axes=None)
            matmul705: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape703, permute_dims705, out_dtype="void")
            add526: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul705, add524)
            rms_norm355: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add526, model_layers_31_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv884 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight5, model_layers_31_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims706: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv884, axes=None)
            matmul706: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm355, permute_dims706, out_dtype="void")
            split175: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul706, indices_or_sections=2, axis=-1)
            split_0175: R.Tensor((1, seq_len, 11008), dtype="float16") = split175[0]
            split_1175: R.Tensor((1, seq_len, 11008), dtype="float16") = split175[1]
            silu175: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0175)
            mul175: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu175, split_1175)
            lv885 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight5, model_layers_31_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims707: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv885, axes=None)
            matmul707: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul175, permute_dims707, out_dtype="void")
            add527: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul707, add526)
            rms_norm356: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add527, model_layers_32_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv886 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight5, model_layers_32_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims708: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv886, axes=None)
            matmul708: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm356, permute_dims708, out_dtype="void")
            add528: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul708, model_layers_32_self_attn_c_attn_bias5)
            reshape704: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add528, R.shape([1, seq_len, 20, 128]))
            reshape705: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape704, R.shape([seq_len, 20, 128]))
            lv887 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape705), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape706: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv887, R.shape([1, seq_len, 16, 128]))
            reshape707: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape706, R.shape([1, seq_len, 2048]))
            lv888 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight5, model_layers_32_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims709: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv888, axes=None)
            matmul709: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape707, permute_dims709, out_dtype="void")
            add529: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul709, add527)
            rms_norm357: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add529, model_layers_32_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv889 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight5, model_layers_32_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims710: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv889, axes=None)
            matmul710: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm357, permute_dims710, out_dtype="void")
            split176: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul710, indices_or_sections=2, axis=-1)
            split_0176: R.Tensor((1, seq_len, 11008), dtype="float16") = split176[0]
            split_1176: R.Tensor((1, seq_len, 11008), dtype="float16") = split176[1]
            silu176: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0176)
            mul176: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu176, split_1176)
            lv890 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight5, model_layers_32_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims711: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv890, axes=None)
            matmul711: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul176, permute_dims711, out_dtype="void")
            add530: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul711, add529)
            rms_norm358: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add530, model_layers_33_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv891 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight5, model_layers_33_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims712: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv891, axes=None)
            matmul712: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm358, permute_dims712, out_dtype="void")
            add531: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul712, model_layers_33_self_attn_c_attn_bias5)
            reshape708: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add531, R.shape([1, seq_len, 20, 128]))
            reshape709: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape708, R.shape([seq_len, 20, 128]))
            lv892 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape709), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape710: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv892, R.shape([1, seq_len, 16, 128]))
            reshape711: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape710, R.shape([1, seq_len, 2048]))
            lv893 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight5, model_layers_33_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims713: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv893, axes=None)
            matmul713: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape711, permute_dims713, out_dtype="void")
            add532: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul713, add530)
            rms_norm359: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add532, model_layers_33_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv894 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight5, model_layers_33_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims714: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv894, axes=None)
            matmul714: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm359, permute_dims714, out_dtype="void")
            split177: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul714, indices_or_sections=2, axis=-1)
            split_0177: R.Tensor((1, seq_len, 11008), dtype="float16") = split177[0]
            split_1177: R.Tensor((1, seq_len, 11008), dtype="float16") = split177[1]
            silu177: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0177)
            mul177: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu177, split_1177)
            lv895 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight5, model_layers_33_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims715: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv895, axes=None)
            matmul715: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul177, permute_dims715, out_dtype="void")
            add533: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul715, add532)
            rms_norm360: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add533, model_layers_34_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv896 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight5, model_layers_34_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims716: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv896, axes=None)
            matmul716: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm360, permute_dims716, out_dtype="void")
            add534: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul716, model_layers_34_self_attn_c_attn_bias5)
            reshape712: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add534, R.shape([1, seq_len, 20, 128]))
            reshape713: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape712, R.shape([seq_len, 20, 128]))
            lv897 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape713), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape714: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv897, R.shape([1, seq_len, 16, 128]))
            reshape715: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape714, R.shape([1, seq_len, 2048]))
            lv898 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight5, model_layers_34_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims717: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv898, axes=None)
            matmul717: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape715, permute_dims717, out_dtype="void")
            add535: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul717, add533)
            rms_norm361: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add535, model_layers_34_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv899 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight5, model_layers_34_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims718: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv899, axes=None)
            matmul718: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm361, permute_dims718, out_dtype="void")
            split178: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul718, indices_or_sections=2, axis=-1)
            split_0178: R.Tensor((1, seq_len, 11008), dtype="float16") = split178[0]
            split_1178: R.Tensor((1, seq_len, 11008), dtype="float16") = split178[1]
            silu178: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0178)
            mul178: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu178, split_1178)
            lv900 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight5, model_layers_34_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims719: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv900, axes=None)
            matmul719: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul178, permute_dims719, out_dtype="void")
            add536: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul719, add535)
            rms_norm362: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add536, model_layers_35_input_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv901 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight5, model_layers_35_self_attn_c_attn_q_scale5), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims720: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv901, axes=None)
            matmul720: R.Tensor((1, seq_len, 2560), dtype="float16") = R.matmul(rms_norm362, permute_dims720, out_dtype="void")
            add537: R.Tensor((1, seq_len, 2560), dtype="float16") = R.add(matmul720, model_layers_35_self_attn_c_attn_bias5)
            reshape716: R.Tensor((1, seq_len, 20, 128), dtype="float16") = R.reshape(add537, R.shape([1, seq_len, 20, 128]))
            reshape717: R.Tensor((seq_len, 20, 128), dtype="float16") = R.reshape(reshape716, R.shape([seq_len, 20, 128]))
            lv902 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape717), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape718: R.Tensor((1, seq_len, 16, 128), dtype="float16") = R.reshape(lv902, R.shape([1, seq_len, 16, 128]))
            reshape719: R.Tensor((1, seq_len, 2048), dtype="float16") = R.reshape(reshape718, R.shape([1, seq_len, 2048]))
            lv903 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight5, model_layers_35_self_attn_o_proj_q_scale5), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims721: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv903, axes=None)
            matmul721: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(reshape719, permute_dims721, out_dtype="void")
            add538: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul721, add536)
            rms_norm363: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add538, model_layers_35_post_attention_layernorm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv904 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight5, model_layers_35_mlp_gate_up_proj_q_scale5), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims722: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv904, axes=None)
            matmul722: R.Tensor((1, seq_len, 22016), dtype="float16") = R.matmul(rms_norm363, permute_dims722, out_dtype="void")
            split179: R.Tuple(R.Tensor((1, seq_len, 11008), dtype="float16"), R.Tensor((1, seq_len, 11008), dtype="float16")) = R.split(matmul722, indices_or_sections=2, axis=-1)
            split_0179: R.Tensor((1, seq_len, 11008), dtype="float16") = split179[0]
            split_1179: R.Tensor((1, seq_len, 11008), dtype="float16") = split179[1]
            silu179: R.Tensor((1, seq_len, 11008), dtype="float16") = R.nn.silu(split_0179)
            mul179: R.Tensor((1, seq_len, 11008), dtype="float16") = R.multiply(silu179, split_1179)
            lv905 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight5, model_layers_35_mlp_down_proj_q_scale5), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims723: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv905, axes=None)
            matmul723: R.Tensor((1, seq_len, 2048), dtype="float16") = R.matmul(mul179, permute_dims723, out_dtype="void")
            add539: R.Tensor((1, seq_len, 2048), dtype="float16") = R.add(matmul723, add538)
            rms_norm364: R.Tensor((1, seq_len, 2048), dtype="float16") = R.nn.rms_norm(add539, model_norm_weight5, axes=[-1], epsilon=9.9999999999999995e-07)
            lv906 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight5, model_embed_tokens_q_scale5), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            permute_dims724: R.Tensor((2048, 151936), dtype="float16") = R.permute_dims(lv906, axes=None)
            matmul724: R.Tensor((1, seq_len, 151936), dtype="float32") = R.matmul(rms_norm364, permute_dims724, out_dtype="float32")
            gv5: R.Tuple(R.Tensor((1, seq_len, 151936), dtype="float32"), R.Object) = matmul724, paged_kv_cache
            R.output(gv5)
        return gv5

    @R.function
    def create_tir_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object:
        # Create a paged attention KV-cache object by calling the VM builtin
        # "vm.builtin.paged_attention_kv_cache_create_reduced", and return it as an opaque R.Object.
        # NOTE(review): presumably this is the object later passed as `paged_kv_cache` to the
        # "vm.builtin.attention_kv_cache_attention_with_fused_qkv" calls in this module — confirm
        # against the host-side caller.
        #
        # Each argument is a one-element shape; its single dimension binds the symbolic int64 var
        # of the same name declared below, so the sizes arrive as shape values rather than tensors.
        max_batch_size = T.int64()
        max_total_seq_len = T.int64()
        prefill_chunk_size = T.int64()
        page_size = T.int64()
        support_sliding_window = T.int64()
        # Attribute dict is identical to the one on `decode` in this module.
        # NOTE(review): the bounded names (batch_size, seq_len, total_seq_len, vocab_size) do not
        # appear in this function's signature, so these entries look inherited by the generator
        # rather than specific to this function — confirm.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        # Zero-dimensional float16 zeros tensor handed to the builtin.
        # NOTE(review): appears to serve only to convey the cache dtype (float16) — confirm.
        gv: R.Tensor((), dtype="float16") = R.zeros(R.shape([]), dtype="float16")
        # Positional arguments, in the order the builtin expects (do not reorder):
        #   - shape of the five runtime sizes bound above;
        #   - R.shape([0, 36]): presumably the [begin, end) layer-id range. Layer ids in this
        #     module's attention calls go up to 35 (e.g. R.prim_value(35)) — confirm.
        #   - 16, 2, 128: presumably number of query heads, number of KV heads, and head dim.
        #     This is consistent with the attention output sinfo (seq_len, 16, 128) and the fused
        #     QKV input of 20 heads (16 + 2 + 2) — confirm.
        #   - 1, 1: meaning not visible here. TODO confirm against the builtin's signature.
        #   - 1000000.0: presumably the rotary-embedding base (theta) used with cls.fused_rope — confirm.
        #   - gv, then the TIR kernels the cache dispatches to: append, paged prefill/decode
        #     (plain and sliding-window), ragged prefill, state merge, RoPE, page copy, debug
        #     readout, compaction copy, and tree attention.
        #   - 0, 0: meaning not visible here. TODO confirm.
        paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.shape([0, 36]), R.prim_value(16), R.prim_value(2), R.prim_value(128), R.prim_value(1), R.prim_value(1), R.prim_value(T.float32(1000000.0)), gv, cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv, R.prim_value(0), R.prim_value(0), sinfo_args=(R.Object,))
        return paged_kv_cache

    @R.function
    def decode(input_embed: R.Tensor((1, 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), 
dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 
256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), 
dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), 
dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 151936), dtype="float32"), R.Object):
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight2: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale2: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm73: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(input_embed, model_layers_0_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv183 = R.call_tir(cls.dequantize1, (model_layers_0_self_attn_c_attn_q_weight2, model_layers_0_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims145: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv183, axes=None)
            matmul145: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm73, permute_dims145, out_dtype="void")
            add108: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul145, model_layers_0_self_attn_c_attn_bias2)
            reshape144: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add108, R.shape([1, 1, 20, 128]))
            reshape145: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape144, R.shape([1, 20, 128]))
            lv184 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape145), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape146: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv184, R.shape([1, 1, 16, 128]))
            reshape147: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape146, R.shape([1, 1, 2048]))
            lv185 = R.call_tir(cls.dequantize2, (model_layers_0_self_attn_o_proj_q_weight2, model_layers_0_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims146: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv185, axes=None)
            matmul146: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape147, permute_dims146, out_dtype="void")
            add109: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul146, input_embed)
            rms_norm74: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add109, model_layers_0_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv186 = R.call_tir(cls.dequantize3, (model_layers_0_mlp_gate_up_proj_q_weight2, model_layers_0_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims147: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv186, axes=None)
            matmul147: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm74, permute_dims147, out_dtype="void")
            split36: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul147, indices_or_sections=2, axis=-1)
            split_036: R.Tensor((1, 1, 11008), dtype="float16") = split36[0]
            split_136: R.Tensor((1, 1, 11008), dtype="float16") = split36[1]
            silu36: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_036)
            mul36: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu36, split_136)
            lv187 = R.call_tir(cls.dequantize4, (model_layers_0_mlp_down_proj_q_weight2, model_layers_0_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims148: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv187, axes=None)
            matmul148: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul36, permute_dims148, out_dtype="void")
            add110: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul148, add109)
            rms_norm75: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add110, model_layers_1_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv188 = R.call_tir(cls.dequantize1, (model_layers_1_self_attn_c_attn_q_weight2, model_layers_1_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims149: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv188, axes=None)
            matmul149: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm75, permute_dims149, out_dtype="void")
            add111: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul149, model_layers_1_self_attn_c_attn_bias2)
            reshape148: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add111, R.shape([1, 1, 20, 128]))
            reshape149: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape148, R.shape([1, 20, 128]))
            lv189 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape149), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape150: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv189, R.shape([1, 1, 16, 128]))
            reshape151: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape150, R.shape([1, 1, 2048]))
            lv190 = R.call_tir(cls.dequantize2, (model_layers_1_self_attn_o_proj_q_weight2, model_layers_1_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims150: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv190, axes=None)
            matmul150: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape151, permute_dims150, out_dtype="void")
            add112: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul150, add110)
            rms_norm76: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add112, model_layers_1_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv191 = R.call_tir(cls.dequantize3, (model_layers_1_mlp_gate_up_proj_q_weight2, model_layers_1_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims151: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv191, axes=None)
            matmul151: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm76, permute_dims151, out_dtype="void")
            split37: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul151, indices_or_sections=2, axis=-1)
            split_037: R.Tensor((1, 1, 11008), dtype="float16") = split37[0]
            split_137: R.Tensor((1, 1, 11008), dtype="float16") = split37[1]
            silu37: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_037)
            mul37: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu37, split_137)
            lv192 = R.call_tir(cls.dequantize4, (model_layers_1_mlp_down_proj_q_weight2, model_layers_1_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims152: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv192, axes=None)
            matmul152: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul37, permute_dims152, out_dtype="void")
            add113: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul152, add112)
            rms_norm77: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add113, model_layers_2_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv193 = R.call_tir(cls.dequantize1, (model_layers_2_self_attn_c_attn_q_weight2, model_layers_2_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims153: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv193, axes=None)
            matmul153: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm77, permute_dims153, out_dtype="void")
            add114: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul153, model_layers_2_self_attn_c_attn_bias2)
            reshape152: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add114, R.shape([1, 1, 20, 128]))
            reshape153: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape152, R.shape([1, 20, 128]))
            lv194 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape153), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape154: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv194, R.shape([1, 1, 16, 128]))
            reshape155: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape154, R.shape([1, 1, 2048]))
            lv195 = R.call_tir(cls.dequantize2, (model_layers_2_self_attn_o_proj_q_weight2, model_layers_2_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims154: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv195, axes=None)
            matmul154: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape155, permute_dims154, out_dtype="void")
            add115: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul154, add113)
            rms_norm78: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add115, model_layers_2_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv196 = R.call_tir(cls.dequantize3, (model_layers_2_mlp_gate_up_proj_q_weight2, model_layers_2_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims155: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv196, axes=None)
            matmul155: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm78, permute_dims155, out_dtype="void")
            split38: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul155, indices_or_sections=2, axis=-1)
            split_038: R.Tensor((1, 1, 11008), dtype="float16") = split38[0]
            split_138: R.Tensor((1, 1, 11008), dtype="float16") = split38[1]
            silu38: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_038)
            mul38: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu38, split_138)
            lv197 = R.call_tir(cls.dequantize4, (model_layers_2_mlp_down_proj_q_weight2, model_layers_2_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims156: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv197, axes=None)
            matmul156: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul38, permute_dims156, out_dtype="void")
            add116: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul156, add115)
            rms_norm79: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add116, model_layers_3_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv198 = R.call_tir(cls.dequantize1, (model_layers_3_self_attn_c_attn_q_weight2, model_layers_3_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims157: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv198, axes=None)
            matmul157: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm79, permute_dims157, out_dtype="void")
            add117: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul157, model_layers_3_self_attn_c_attn_bias2)
            reshape156: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add117, R.shape([1, 1, 20, 128]))
            reshape157: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape156, R.shape([1, 20, 128]))
            lv199 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape157), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape158: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv199, R.shape([1, 1, 16, 128]))
            reshape159: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape158, R.shape([1, 1, 2048]))
            lv200 = R.call_tir(cls.dequantize2, (model_layers_3_self_attn_o_proj_q_weight2, model_layers_3_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims158: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv200, axes=None)
            matmul158: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape159, permute_dims158, out_dtype="void")
            add118: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul158, add116)
            rms_norm80: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add118, model_layers_3_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv201 = R.call_tir(cls.dequantize3, (model_layers_3_mlp_gate_up_proj_q_weight2, model_layers_3_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims159: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv201, axes=None)
            matmul159: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm80, permute_dims159, out_dtype="void")
            split39: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul159, indices_or_sections=2, axis=-1)
            split_039: R.Tensor((1, 1, 11008), dtype="float16") = split39[0]
            split_139: R.Tensor((1, 1, 11008), dtype="float16") = split39[1]
            silu39: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_039)
            mul39: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu39, split_139)
            lv202 = R.call_tir(cls.dequantize4, (model_layers_3_mlp_down_proj_q_weight2, model_layers_3_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims160: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv202, axes=None)
            matmul160: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul39, permute_dims160, out_dtype="void")
            add119: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul160, add118)
            rms_norm81: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add119, model_layers_4_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv203 = R.call_tir(cls.dequantize1, (model_layers_4_self_attn_c_attn_q_weight2, model_layers_4_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims161: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv203, axes=None)
            matmul161: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm81, permute_dims161, out_dtype="void")
            add120: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul161, model_layers_4_self_attn_c_attn_bias2)
            reshape160: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add120, R.shape([1, 1, 20, 128]))
            reshape161: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape160, R.shape([1, 20, 128]))
            lv204 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape161), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape162: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv204, R.shape([1, 1, 16, 128]))
            reshape163: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape162, R.shape([1, 1, 2048]))
            lv205 = R.call_tir(cls.dequantize2, (model_layers_4_self_attn_o_proj_q_weight2, model_layers_4_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims162: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv205, axes=None)
            matmul162: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape163, permute_dims162, out_dtype="void")
            add121: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul162, add119)
            rms_norm82: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add121, model_layers_4_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv206 = R.call_tir(cls.dequantize3, (model_layers_4_mlp_gate_up_proj_q_weight2, model_layers_4_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims163: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv206, axes=None)
            matmul163: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm82, permute_dims163, out_dtype="void")
            split40: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul163, indices_or_sections=2, axis=-1)
            split_040: R.Tensor((1, 1, 11008), dtype="float16") = split40[0]
            split_140: R.Tensor((1, 1, 11008), dtype="float16") = split40[1]
            silu40: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_040)
            mul40: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu40, split_140)
            lv207 = R.call_tir(cls.dequantize4, (model_layers_4_mlp_down_proj_q_weight2, model_layers_4_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims164: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv207, axes=None)
            matmul164: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul40, permute_dims164, out_dtype="void")
            add122: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul164, add121)
            rms_norm83: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add122, model_layers_5_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv208 = R.call_tir(cls.dequantize1, (model_layers_5_self_attn_c_attn_q_weight2, model_layers_5_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims165: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv208, axes=None)
            matmul165: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm83, permute_dims165, out_dtype="void")
            add123: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul165, model_layers_5_self_attn_c_attn_bias2)
            reshape164: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add123, R.shape([1, 1, 20, 128]))
            reshape165: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape164, R.shape([1, 20, 128]))
            lv209 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape165), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape166: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv209, R.shape([1, 1, 16, 128]))
            reshape167: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape166, R.shape([1, 1, 2048]))
            lv210 = R.call_tir(cls.dequantize2, (model_layers_5_self_attn_o_proj_q_weight2, model_layers_5_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims166: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv210, axes=None)
            matmul166: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape167, permute_dims166, out_dtype="void")
            add124: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul166, add122)
            rms_norm84: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add124, model_layers_5_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv211 = R.call_tir(cls.dequantize3, (model_layers_5_mlp_gate_up_proj_q_weight2, model_layers_5_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims167: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv211, axes=None)
            matmul167: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm84, permute_dims167, out_dtype="void")
            split41: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul167, indices_or_sections=2, axis=-1)
            split_041: R.Tensor((1, 1, 11008), dtype="float16") = split41[0]
            split_141: R.Tensor((1, 1, 11008), dtype="float16") = split41[1]
            silu41: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_041)
            mul41: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu41, split_141)
            lv212 = R.call_tir(cls.dequantize4, (model_layers_5_mlp_down_proj_q_weight2, model_layers_5_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims168: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv212, axes=None)
            matmul168: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul41, permute_dims168, out_dtype="void")
            add125: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul168, add124)
            rms_norm85: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add125, model_layers_6_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv213 = R.call_tir(cls.dequantize1, (model_layers_6_self_attn_c_attn_q_weight2, model_layers_6_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims169: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv213, axes=None)
            matmul169: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm85, permute_dims169, out_dtype="void")
            add126: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul169, model_layers_6_self_attn_c_attn_bias2)
            reshape168: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add126, R.shape([1, 1, 20, 128]))
            reshape169: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape168, R.shape([1, 20, 128]))
            lv214 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape169), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape170: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv214, R.shape([1, 1, 16, 128]))
            reshape171: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape170, R.shape([1, 1, 2048]))
            lv215 = R.call_tir(cls.dequantize2, (model_layers_6_self_attn_o_proj_q_weight2, model_layers_6_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims170: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv215, axes=None)
            matmul170: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape171, permute_dims170, out_dtype="void")
            add127: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul170, add125)
            rms_norm86: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add127, model_layers_6_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv216 = R.call_tir(cls.dequantize3, (model_layers_6_mlp_gate_up_proj_q_weight2, model_layers_6_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims171: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv216, axes=None)
            matmul171: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm86, permute_dims171, out_dtype="void")
            split42: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul171, indices_or_sections=2, axis=-1)
            split_042: R.Tensor((1, 1, 11008), dtype="float16") = split42[0]
            split_142: R.Tensor((1, 1, 11008), dtype="float16") = split42[1]
            silu42: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_042)
            mul42: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu42, split_142)
            lv217 = R.call_tir(cls.dequantize4, (model_layers_6_mlp_down_proj_q_weight2, model_layers_6_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims172: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv217, axes=None)
            matmul172: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul42, permute_dims172, out_dtype="void")
            add128: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul172, add127)
            rms_norm87: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add128, model_layers_7_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv218 = R.call_tir(cls.dequantize1, (model_layers_7_self_attn_c_attn_q_weight2, model_layers_7_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims173: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv218, axes=None)
            matmul173: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm87, permute_dims173, out_dtype="void")
            add129: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul173, model_layers_7_self_attn_c_attn_bias2)
            reshape172: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add129, R.shape([1, 1, 20, 128]))
            reshape173: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape172, R.shape([1, 20, 128]))
            lv219 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape173), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape174: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv219, R.shape([1, 1, 16, 128]))
            reshape175: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape174, R.shape([1, 1, 2048]))
            lv220 = R.call_tir(cls.dequantize2, (model_layers_7_self_attn_o_proj_q_weight2, model_layers_7_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims174: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv220, axes=None)
            matmul174: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape175, permute_dims174, out_dtype="void")
            add130: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul174, add128)
            rms_norm88: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add130, model_layers_7_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv221 = R.call_tir(cls.dequantize3, (model_layers_7_mlp_gate_up_proj_q_weight2, model_layers_7_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims175: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv221, axes=None)
            matmul175: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm88, permute_dims175, out_dtype="void")
            split43: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul175, indices_or_sections=2, axis=-1)
            split_043: R.Tensor((1, 1, 11008), dtype="float16") = split43[0]
            split_143: R.Tensor((1, 1, 11008), dtype="float16") = split43[1]
            silu43: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_043)
            mul43: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu43, split_143)
            lv222 = R.call_tir(cls.dequantize4, (model_layers_7_mlp_down_proj_q_weight2, model_layers_7_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims176: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv222, axes=None)
            matmul176: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul43, permute_dims176, out_dtype="void")
            add131: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul176, add130)
            rms_norm89: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add131, model_layers_8_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv223 = R.call_tir(cls.dequantize1, (model_layers_8_self_attn_c_attn_q_weight2, model_layers_8_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims177: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv223, axes=None)
            matmul177: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm89, permute_dims177, out_dtype="void")
            add132: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul177, model_layers_8_self_attn_c_attn_bias2)
            reshape176: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add132, R.shape([1, 1, 20, 128]))
            reshape177: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape176, R.shape([1, 20, 128]))
            lv224 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape177), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape178: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv224, R.shape([1, 1, 16, 128]))
            reshape179: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape178, R.shape([1, 1, 2048]))
            lv225 = R.call_tir(cls.dequantize2, (model_layers_8_self_attn_o_proj_q_weight2, model_layers_8_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims178: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv225, axes=None)
            matmul178: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape179, permute_dims178, out_dtype="void")
            add133: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul178, add131)
            rms_norm90: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add133, model_layers_8_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv226 = R.call_tir(cls.dequantize3, (model_layers_8_mlp_gate_up_proj_q_weight2, model_layers_8_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims179: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv226, axes=None)
            matmul179: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm90, permute_dims179, out_dtype="void")
            split44: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul179, indices_or_sections=2, axis=-1)
            split_044: R.Tensor((1, 1, 11008), dtype="float16") = split44[0]
            split_144: R.Tensor((1, 1, 11008), dtype="float16") = split44[1]
            silu44: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_044)
            mul44: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu44, split_144)
            lv227 = R.call_tir(cls.dequantize4, (model_layers_8_mlp_down_proj_q_weight2, model_layers_8_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims180: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv227, axes=None)
            matmul180: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul44, permute_dims180, out_dtype="void")
            add134: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul180, add133)
            rms_norm91: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add134, model_layers_9_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv228 = R.call_tir(cls.dequantize1, (model_layers_9_self_attn_c_attn_q_weight2, model_layers_9_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims181: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv228, axes=None)
            matmul181: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm91, permute_dims181, out_dtype="void")
            add135: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul181, model_layers_9_self_attn_c_attn_bias2)
            reshape180: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add135, R.shape([1, 1, 20, 128]))
            reshape181: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape180, R.shape([1, 20, 128]))
            lv229 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape181), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape182: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv229, R.shape([1, 1, 16, 128]))
            reshape183: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape182, R.shape([1, 1, 2048]))
            lv230 = R.call_tir(cls.dequantize2, (model_layers_9_self_attn_o_proj_q_weight2, model_layers_9_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims182: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv230, axes=None)
            matmul182: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape183, permute_dims182, out_dtype="void")
            add136: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul182, add134)
            rms_norm92: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add136, model_layers_9_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv231 = R.call_tir(cls.dequantize3, (model_layers_9_mlp_gate_up_proj_q_weight2, model_layers_9_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims183: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv231, axes=None)
            matmul183: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm92, permute_dims183, out_dtype="void")
            split45: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul183, indices_or_sections=2, axis=-1)
            split_045: R.Tensor((1, 1, 11008), dtype="float16") = split45[0]
            split_145: R.Tensor((1, 1, 11008), dtype="float16") = split45[1]
            silu45: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_045)
            mul45: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu45, split_145)
            lv232 = R.call_tir(cls.dequantize4, (model_layers_9_mlp_down_proj_q_weight2, model_layers_9_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims184: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv232, axes=None)
            matmul184: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul45, permute_dims184, out_dtype="void")
            add137: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul184, add136)
            rms_norm93: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add137, model_layers_10_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv233 = R.call_tir(cls.dequantize1, (model_layers_10_self_attn_c_attn_q_weight2, model_layers_10_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims185: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv233, axes=None)
            matmul185: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm93, permute_dims185, out_dtype="void")
            add138: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul185, model_layers_10_self_attn_c_attn_bias2)
            reshape184: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add138, R.shape([1, 1, 20, 128]))
            reshape185: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape184, R.shape([1, 20, 128]))
            lv234 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape185), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape186: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv234, R.shape([1, 1, 16, 128]))
            reshape187: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape186, R.shape([1, 1, 2048]))
            lv235 = R.call_tir(cls.dequantize2, (model_layers_10_self_attn_o_proj_q_weight2, model_layers_10_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims186: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv235, axes=None)
            matmul186: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape187, permute_dims186, out_dtype="void")
            add139: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul186, add137)
            rms_norm94: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add139, model_layers_10_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv236 = R.call_tir(cls.dequantize3, (model_layers_10_mlp_gate_up_proj_q_weight2, model_layers_10_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims187: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv236, axes=None)
            matmul187: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm94, permute_dims187, out_dtype="void")
            split46: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul187, indices_or_sections=2, axis=-1)
            split_046: R.Tensor((1, 1, 11008), dtype="float16") = split46[0]
            split_146: R.Tensor((1, 1, 11008), dtype="float16") = split46[1]
            silu46: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_046)
            mul46: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu46, split_146)
            lv237 = R.call_tir(cls.dequantize4, (model_layers_10_mlp_down_proj_q_weight2, model_layers_10_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims188: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv237, axes=None)
            matmul188: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul46, permute_dims188, out_dtype="void")
            add140: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul188, add139)
            rms_norm95: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add140, model_layers_11_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv238 = R.call_tir(cls.dequantize1, (model_layers_11_self_attn_c_attn_q_weight2, model_layers_11_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims189: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv238, axes=None)
            matmul189: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm95, permute_dims189, out_dtype="void")
            add141: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul189, model_layers_11_self_attn_c_attn_bias2)
            reshape188: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add141, R.shape([1, 1, 20, 128]))
            reshape189: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape188, R.shape([1, 20, 128]))
            lv239 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape189), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape190: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv239, R.shape([1, 1, 16, 128]))
            reshape191: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape190, R.shape([1, 1, 2048]))
            lv240 = R.call_tir(cls.dequantize2, (model_layers_11_self_attn_o_proj_q_weight2, model_layers_11_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims190: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv240, axes=None)
            matmul190: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape191, permute_dims190, out_dtype="void")
            add142: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul190, add140)
            rms_norm96: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add142, model_layers_11_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv241 = R.call_tir(cls.dequantize3, (model_layers_11_mlp_gate_up_proj_q_weight2, model_layers_11_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims191: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv241, axes=None)
            matmul191: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm96, permute_dims191, out_dtype="void")
            split47: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul191, indices_or_sections=2, axis=-1)
            split_047: R.Tensor((1, 1, 11008), dtype="float16") = split47[0]
            split_147: R.Tensor((1, 1, 11008), dtype="float16") = split47[1]
            silu47: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_047)
            mul47: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu47, split_147)
            lv242 = R.call_tir(cls.dequantize4, (model_layers_11_mlp_down_proj_q_weight2, model_layers_11_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims192: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv242, axes=None)
            matmul192: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul47, permute_dims192, out_dtype="void")
            add143: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul192, add142)
            rms_norm97: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add143, model_layers_12_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv243 = R.call_tir(cls.dequantize1, (model_layers_12_self_attn_c_attn_q_weight2, model_layers_12_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims193: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv243, axes=None)
            matmul193: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm97, permute_dims193, out_dtype="void")
            add144: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul193, model_layers_12_self_attn_c_attn_bias2)
            reshape192: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add144, R.shape([1, 1, 20, 128]))
            reshape193: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape192, R.shape([1, 20, 128]))
            lv244 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape193), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape194: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv244, R.shape([1, 1, 16, 128]))
            reshape195: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape194, R.shape([1, 1, 2048]))
            lv245 = R.call_tir(cls.dequantize2, (model_layers_12_self_attn_o_proj_q_weight2, model_layers_12_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims194: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv245, axes=None)
            matmul194: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape195, permute_dims194, out_dtype="void")
            add145: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul194, add143)
            rms_norm98: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add145, model_layers_12_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv246 = R.call_tir(cls.dequantize3, (model_layers_12_mlp_gate_up_proj_q_weight2, model_layers_12_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims195: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv246, axes=None)
            matmul195: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm98, permute_dims195, out_dtype="void")
            split48: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul195, indices_or_sections=2, axis=-1)
            split_048: R.Tensor((1, 1, 11008), dtype="float16") = split48[0]
            split_148: R.Tensor((1, 1, 11008), dtype="float16") = split48[1]
            silu48: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_048)
            mul48: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu48, split_148)
            lv247 = R.call_tir(cls.dequantize4, (model_layers_12_mlp_down_proj_q_weight2, model_layers_12_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims196: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv247, axes=None)
            matmul196: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul48, permute_dims196, out_dtype="void")
            add146: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul196, add145)
            rms_norm99: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add146, model_layers_13_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv248 = R.call_tir(cls.dequantize1, (model_layers_13_self_attn_c_attn_q_weight2, model_layers_13_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims197: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv248, axes=None)
            matmul197: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm99, permute_dims197, out_dtype="void")
            add147: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul197, model_layers_13_self_attn_c_attn_bias2)
            reshape196: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add147, R.shape([1, 1, 20, 128]))
            reshape197: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape196, R.shape([1, 20, 128]))
            lv249 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape197), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape198: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv249, R.shape([1, 1, 16, 128]))
            reshape199: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape198, R.shape([1, 1, 2048]))
            lv250 = R.call_tir(cls.dequantize2, (model_layers_13_self_attn_o_proj_q_weight2, model_layers_13_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims198: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv250, axes=None)
            matmul198: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape199, permute_dims198, out_dtype="void")
            add148: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul198, add146)
            rms_norm100: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add148, model_layers_13_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv251 = R.call_tir(cls.dequantize3, (model_layers_13_mlp_gate_up_proj_q_weight2, model_layers_13_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims199: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv251, axes=None)
            matmul199: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm100, permute_dims199, out_dtype="void")
            split49: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul199, indices_or_sections=2, axis=-1)
            split_049: R.Tensor((1, 1, 11008), dtype="float16") = split49[0]
            split_149: R.Tensor((1, 1, 11008), dtype="float16") = split49[1]
            silu49: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_049)
            mul49: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu49, split_149)
            lv252 = R.call_tir(cls.dequantize4, (model_layers_13_mlp_down_proj_q_weight2, model_layers_13_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims200: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv252, axes=None)
            matmul200: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul49, permute_dims200, out_dtype="void")
            add149: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul200, add148)
            rms_norm101: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add149, model_layers_14_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv253 = R.call_tir(cls.dequantize1, (model_layers_14_self_attn_c_attn_q_weight2, model_layers_14_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims201: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv253, axes=None)
            matmul201: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm101, permute_dims201, out_dtype="void")
            add150: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul201, model_layers_14_self_attn_c_attn_bias2)
            reshape200: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add150, R.shape([1, 1, 20, 128]))
            reshape201: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape200, R.shape([1, 20, 128]))
            lv254 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape201), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape202: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv254, R.shape([1, 1, 16, 128]))
            reshape203: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape202, R.shape([1, 1, 2048]))
            lv255 = R.call_tir(cls.dequantize2, (model_layers_14_self_attn_o_proj_q_weight2, model_layers_14_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims202: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv255, axes=None)
            matmul202: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape203, permute_dims202, out_dtype="void")
            add151: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul202, add149)
            rms_norm102: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add151, model_layers_14_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv256 = R.call_tir(cls.dequantize3, (model_layers_14_mlp_gate_up_proj_q_weight2, model_layers_14_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims203: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv256, axes=None)
            matmul203: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm102, permute_dims203, out_dtype="void")
            split50: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul203, indices_or_sections=2, axis=-1)
            split_050: R.Tensor((1, 1, 11008), dtype="float16") = split50[0]
            split_150: R.Tensor((1, 1, 11008), dtype="float16") = split50[1]
            silu50: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_050)
            mul50: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu50, split_150)
            lv257 = R.call_tir(cls.dequantize4, (model_layers_14_mlp_down_proj_q_weight2, model_layers_14_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims204: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv257, axes=None)
            matmul204: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul50, permute_dims204, out_dtype="void")
            add152: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul204, add151)
            rms_norm103: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add152, model_layers_15_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv258 = R.call_tir(cls.dequantize1, (model_layers_15_self_attn_c_attn_q_weight2, model_layers_15_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims205: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv258, axes=None)
            matmul205: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm103, permute_dims205, out_dtype="void")
            add153: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul205, model_layers_15_self_attn_c_attn_bias2)
            reshape204: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add153, R.shape([1, 1, 20, 128]))
            reshape205: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape204, R.shape([1, 20, 128]))
            lv259 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape205), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape206: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv259, R.shape([1, 1, 16, 128]))
            reshape207: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape206, R.shape([1, 1, 2048]))
            lv260 = R.call_tir(cls.dequantize2, (model_layers_15_self_attn_o_proj_q_weight2, model_layers_15_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims206: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv260, axes=None)
            matmul206: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape207, permute_dims206, out_dtype="void")
            add154: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul206, add152)
            rms_norm104: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add154, model_layers_15_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv261 = R.call_tir(cls.dequantize3, (model_layers_15_mlp_gate_up_proj_q_weight2, model_layers_15_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims207: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv261, axes=None)
            matmul207: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm104, permute_dims207, out_dtype="void")
            split51: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul207, indices_or_sections=2, axis=-1)
            split_051: R.Tensor((1, 1, 11008), dtype="float16") = split51[0]
            split_151: R.Tensor((1, 1, 11008), dtype="float16") = split51[1]
            silu51: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_051)
            mul51: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu51, split_151)
            lv262 = R.call_tir(cls.dequantize4, (model_layers_15_mlp_down_proj_q_weight2, model_layers_15_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims208: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv262, axes=None)
            matmul208: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul51, permute_dims208, out_dtype="void")
            add155: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul208, add154)
            rms_norm105: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add155, model_layers_16_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv263 = R.call_tir(cls.dequantize1, (model_layers_16_self_attn_c_attn_q_weight2, model_layers_16_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims209: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv263, axes=None)
            matmul209: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm105, permute_dims209, out_dtype="void")
            add156: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul209, model_layers_16_self_attn_c_attn_bias2)
            reshape208: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add156, R.shape([1, 1, 20, 128]))
            reshape209: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape208, R.shape([1, 20, 128]))
            lv264 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape209), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape210: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv264, R.shape([1, 1, 16, 128]))
            reshape211: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape210, R.shape([1, 1, 2048]))
            lv265 = R.call_tir(cls.dequantize2, (model_layers_16_self_attn_o_proj_q_weight2, model_layers_16_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims210: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv265, axes=None)
            matmul210: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape211, permute_dims210, out_dtype="void")
            add157: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul210, add155)
            rms_norm106: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add157, model_layers_16_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv266 = R.call_tir(cls.dequantize3, (model_layers_16_mlp_gate_up_proj_q_weight2, model_layers_16_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims211: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv266, axes=None)
            matmul211: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm106, permute_dims211, out_dtype="void")
            split52: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul211, indices_or_sections=2, axis=-1)
            split_052: R.Tensor((1, 1, 11008), dtype="float16") = split52[0]
            split_152: R.Tensor((1, 1, 11008), dtype="float16") = split52[1]
            silu52: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_052)
            mul52: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu52, split_152)
            lv267 = R.call_tir(cls.dequantize4, (model_layers_16_mlp_down_proj_q_weight2, model_layers_16_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims212: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv267, axes=None)
            matmul212: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul52, permute_dims212, out_dtype="void")
            add158: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul212, add157)
            rms_norm107: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add158, model_layers_17_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv268 = R.call_tir(cls.dequantize1, (model_layers_17_self_attn_c_attn_q_weight2, model_layers_17_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims213: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv268, axes=None)
            matmul213: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm107, permute_dims213, out_dtype="void")
            add159: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul213, model_layers_17_self_attn_c_attn_bias2)
            reshape212: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add159, R.shape([1, 1, 20, 128]))
            reshape213: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape212, R.shape([1, 20, 128]))
            lv269 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape213), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape214: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv269, R.shape([1, 1, 16, 128]))
            reshape215: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape214, R.shape([1, 1, 2048]))
            lv270 = R.call_tir(cls.dequantize2, (model_layers_17_self_attn_o_proj_q_weight2, model_layers_17_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims214: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv270, axes=None)
            matmul214: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape215, permute_dims214, out_dtype="void")
            add160: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul214, add158)
            rms_norm108: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add160, model_layers_17_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv271 = R.call_tir(cls.dequantize3, (model_layers_17_mlp_gate_up_proj_q_weight2, model_layers_17_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims215: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv271, axes=None)
            matmul215: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm108, permute_dims215, out_dtype="void")
            split53: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul215, indices_or_sections=2, axis=-1)
            split_053: R.Tensor((1, 1, 11008), dtype="float16") = split53[0]
            split_153: R.Tensor((1, 1, 11008), dtype="float16") = split53[1]
            silu53: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_053)
            mul53: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu53, split_153)
            lv272 = R.call_tir(cls.dequantize4, (model_layers_17_mlp_down_proj_q_weight2, model_layers_17_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims216: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv272, axes=None)
            matmul216: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul53, permute_dims216, out_dtype="void")
            add161: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul216, add160)
            rms_norm109: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add161, model_layers_18_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv273 = R.call_tir(cls.dequantize1, (model_layers_18_self_attn_c_attn_q_weight2, model_layers_18_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims217: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv273, axes=None)
            matmul217: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm109, permute_dims217, out_dtype="void")
            add162: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul217, model_layers_18_self_attn_c_attn_bias2)
            reshape216: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add162, R.shape([1, 1, 20, 128]))
            reshape217: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape216, R.shape([1, 20, 128]))
            lv274 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape217), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape218: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv274, R.shape([1, 1, 16, 128]))
            reshape219: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape218, R.shape([1, 1, 2048]))
            lv275 = R.call_tir(cls.dequantize2, (model_layers_18_self_attn_o_proj_q_weight2, model_layers_18_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims218: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv275, axes=None)
            matmul218: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape219, permute_dims218, out_dtype="void")
            add163: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul218, add161)
            rms_norm110: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add163, model_layers_18_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv276 = R.call_tir(cls.dequantize3, (model_layers_18_mlp_gate_up_proj_q_weight2, model_layers_18_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims219: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv276, axes=None)
            matmul219: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm110, permute_dims219, out_dtype="void")
            split54: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul219, indices_or_sections=2, axis=-1)
            split_054: R.Tensor((1, 1, 11008), dtype="float16") = split54[0]
            split_154: R.Tensor((1, 1, 11008), dtype="float16") = split54[1]
            silu54: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_054)
            mul54: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu54, split_154)
            lv277 = R.call_tir(cls.dequantize4, (model_layers_18_mlp_down_proj_q_weight2, model_layers_18_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims220: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv277, axes=None)
            matmul220: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul54, permute_dims220, out_dtype="void")
            add164: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul220, add163)
            rms_norm111: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add164, model_layers_19_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv278 = R.call_tir(cls.dequantize1, (model_layers_19_self_attn_c_attn_q_weight2, model_layers_19_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims221: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv278, axes=None)
            matmul221: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm111, permute_dims221, out_dtype="void")
            add165: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul221, model_layers_19_self_attn_c_attn_bias2)
            reshape220: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add165, R.shape([1, 1, 20, 128]))
            reshape221: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape220, R.shape([1, 20, 128]))
            lv279 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape221), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape222: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv279, R.shape([1, 1, 16, 128]))
            reshape223: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape222, R.shape([1, 1, 2048]))
            lv280 = R.call_tir(cls.dequantize2, (model_layers_19_self_attn_o_proj_q_weight2, model_layers_19_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims222: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv280, axes=None)
            matmul222: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape223, permute_dims222, out_dtype="void")
            add166: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul222, add164)
            rms_norm112: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add166, model_layers_19_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv281 = R.call_tir(cls.dequantize3, (model_layers_19_mlp_gate_up_proj_q_weight2, model_layers_19_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims223: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv281, axes=None)
            matmul223: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm112, permute_dims223, out_dtype="void")
            split55: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul223, indices_or_sections=2, axis=-1)
            split_055: R.Tensor((1, 1, 11008), dtype="float16") = split55[0]
            split_155: R.Tensor((1, 1, 11008), dtype="float16") = split55[1]
            silu55: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_055)
            mul55: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu55, split_155)
            lv282 = R.call_tir(cls.dequantize4, (model_layers_19_mlp_down_proj_q_weight2, model_layers_19_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims224: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv282, axes=None)
            matmul224: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul55, permute_dims224, out_dtype="void")
            add167: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul224, add166)
            rms_norm113: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add167, model_layers_20_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv283 = R.call_tir(cls.dequantize1, (model_layers_20_self_attn_c_attn_q_weight2, model_layers_20_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims225: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv283, axes=None)
            matmul225: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm113, permute_dims225, out_dtype="void")
            add168: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul225, model_layers_20_self_attn_c_attn_bias2)
            reshape224: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add168, R.shape([1, 1, 20, 128]))
            reshape225: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape224, R.shape([1, 20, 128]))
            lv284 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape225), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape226: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv284, R.shape([1, 1, 16, 128]))
            reshape227: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape226, R.shape([1, 1, 2048]))
            lv285 = R.call_tir(cls.dequantize2, (model_layers_20_self_attn_o_proj_q_weight2, model_layers_20_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims226: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv285, axes=None)
            matmul226: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape227, permute_dims226, out_dtype="void")
            add169: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul226, add167)
            rms_norm114: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add169, model_layers_20_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv286 = R.call_tir(cls.dequantize3, (model_layers_20_mlp_gate_up_proj_q_weight2, model_layers_20_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims227: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv286, axes=None)
            matmul227: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm114, permute_dims227, out_dtype="void")
            split56: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul227, indices_or_sections=2, axis=-1)
            split_056: R.Tensor((1, 1, 11008), dtype="float16") = split56[0]
            split_156: R.Tensor((1, 1, 11008), dtype="float16") = split56[1]
            silu56: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_056)
            mul56: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu56, split_156)
            lv287 = R.call_tir(cls.dequantize4, (model_layers_20_mlp_down_proj_q_weight2, model_layers_20_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims228: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv287, axes=None)
            matmul228: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul56, permute_dims228, out_dtype="void")
            add170: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul228, add169)
            rms_norm115: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add170, model_layers_21_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv288 = R.call_tir(cls.dequantize1, (model_layers_21_self_attn_c_attn_q_weight2, model_layers_21_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims229: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv288, axes=None)
            matmul229: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm115, permute_dims229, out_dtype="void")
            add171: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul229, model_layers_21_self_attn_c_attn_bias2)
            reshape228: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add171, R.shape([1, 1, 20, 128]))
            reshape229: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape228, R.shape([1, 20, 128]))
            lv289 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape229), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape230: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv289, R.shape([1, 1, 16, 128]))
            reshape231: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape230, R.shape([1, 1, 2048]))
            lv290 = R.call_tir(cls.dequantize2, (model_layers_21_self_attn_o_proj_q_weight2, model_layers_21_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims230: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv290, axes=None)
            matmul230: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape231, permute_dims230, out_dtype="void")
            add172: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul230, add170)
            rms_norm116: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add172, model_layers_21_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv291 = R.call_tir(cls.dequantize3, (model_layers_21_mlp_gate_up_proj_q_weight2, model_layers_21_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims231: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv291, axes=None)
            matmul231: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm116, permute_dims231, out_dtype="void")
            split57: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul231, indices_or_sections=2, axis=-1)
            split_057: R.Tensor((1, 1, 11008), dtype="float16") = split57[0]
            split_157: R.Tensor((1, 1, 11008), dtype="float16") = split57[1]
            silu57: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_057)
            mul57: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu57, split_157)
            lv292 = R.call_tir(cls.dequantize4, (model_layers_21_mlp_down_proj_q_weight2, model_layers_21_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims232: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv292, axes=None)
            matmul232: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul57, permute_dims232, out_dtype="void")
            add173: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul232, add172)
            rms_norm117: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add173, model_layers_22_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv293 = R.call_tir(cls.dequantize1, (model_layers_22_self_attn_c_attn_q_weight2, model_layers_22_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims233: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv293, axes=None)
            matmul233: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm117, permute_dims233, out_dtype="void")
            add174: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul233, model_layers_22_self_attn_c_attn_bias2)
            reshape232: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add174, R.shape([1, 1, 20, 128]))
            reshape233: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape232, R.shape([1, 20, 128]))
            lv294 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape233), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape234: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv294, R.shape([1, 1, 16, 128]))
            reshape235: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape234, R.shape([1, 1, 2048]))
            lv295 = R.call_tir(cls.dequantize2, (model_layers_22_self_attn_o_proj_q_weight2, model_layers_22_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims234: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv295, axes=None)
            matmul234: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape235, permute_dims234, out_dtype="void")
            add175: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul234, add173)
            rms_norm118: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add175, model_layers_22_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv296 = R.call_tir(cls.dequantize3, (model_layers_22_mlp_gate_up_proj_q_weight2, model_layers_22_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims235: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv296, axes=None)
            matmul235: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm118, permute_dims235, out_dtype="void")
            split58: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul235, indices_or_sections=2, axis=-1)
            split_058: R.Tensor((1, 1, 11008), dtype="float16") = split58[0]
            split_158: R.Tensor((1, 1, 11008), dtype="float16") = split58[1]
            silu58: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_058)
            mul58: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu58, split_158)
            lv297 = R.call_tir(cls.dequantize4, (model_layers_22_mlp_down_proj_q_weight2, model_layers_22_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims236: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv297, axes=None)
            matmul236: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul58, permute_dims236, out_dtype="void")
            add176: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul236, add175)
            rms_norm119: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add176, model_layers_23_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv298 = R.call_tir(cls.dequantize1, (model_layers_23_self_attn_c_attn_q_weight2, model_layers_23_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims237: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv298, axes=None)
            matmul237: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm119, permute_dims237, out_dtype="void")
            add177: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul237, model_layers_23_self_attn_c_attn_bias2)
            reshape236: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add177, R.shape([1, 1, 20, 128]))
            reshape237: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape236, R.shape([1, 20, 128]))
            lv299 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape237), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape238: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv299, R.shape([1, 1, 16, 128]))
            reshape239: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape238, R.shape([1, 1, 2048]))
            lv300 = R.call_tir(cls.dequantize2, (model_layers_23_self_attn_o_proj_q_weight2, model_layers_23_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims238: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv300, axes=None)
            matmul238: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape239, permute_dims238, out_dtype="void")
            add178: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul238, add176)
            rms_norm120: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add178, model_layers_23_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv301 = R.call_tir(cls.dequantize3, (model_layers_23_mlp_gate_up_proj_q_weight2, model_layers_23_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims239: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv301, axes=None)
            matmul239: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm120, permute_dims239, out_dtype="void")
            split59: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul239, indices_or_sections=2, axis=-1)
            split_059: R.Tensor((1, 1, 11008), dtype="float16") = split59[0]
            split_159: R.Tensor((1, 1, 11008), dtype="float16") = split59[1]
            silu59: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_059)
            mul59: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu59, split_159)
            lv302 = R.call_tir(cls.dequantize4, (model_layers_23_mlp_down_proj_q_weight2, model_layers_23_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims240: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv302, axes=None)
            matmul240: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul59, permute_dims240, out_dtype="void")
            add179: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul240, add178)
            rms_norm121: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add179, model_layers_24_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv303 = R.call_tir(cls.dequantize1, (model_layers_24_self_attn_c_attn_q_weight2, model_layers_24_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims241: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv303, axes=None)
            matmul241: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm121, permute_dims241, out_dtype="void")
            add180: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul241, model_layers_24_self_attn_c_attn_bias2)
            reshape240: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add180, R.shape([1, 1, 20, 128]))
            reshape241: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape240, R.shape([1, 20, 128]))
            lv304 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape241), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape242: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv304, R.shape([1, 1, 16, 128]))
            reshape243: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape242, R.shape([1, 1, 2048]))
            lv305 = R.call_tir(cls.dequantize2, (model_layers_24_self_attn_o_proj_q_weight2, model_layers_24_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims242: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv305, axes=None)
            matmul242: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape243, permute_dims242, out_dtype="void")
            add181: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul242, add179)
            rms_norm122: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add181, model_layers_24_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv306 = R.call_tir(cls.dequantize3, (model_layers_24_mlp_gate_up_proj_q_weight2, model_layers_24_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims243: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv306, axes=None)
            matmul243: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm122, permute_dims243, out_dtype="void")
            split60: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul243, indices_or_sections=2, axis=-1)
            split_060: R.Tensor((1, 1, 11008), dtype="float16") = split60[0]
            split_160: R.Tensor((1, 1, 11008), dtype="float16") = split60[1]
            silu60: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_060)
            mul60: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu60, split_160)
            lv307 = R.call_tir(cls.dequantize4, (model_layers_24_mlp_down_proj_q_weight2, model_layers_24_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims244: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv307, axes=None)
            matmul244: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul60, permute_dims244, out_dtype="void")
            add182: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul244, add181)
            rms_norm123: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add182, model_layers_25_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv308 = R.call_tir(cls.dequantize1, (model_layers_25_self_attn_c_attn_q_weight2, model_layers_25_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims245: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv308, axes=None)
            matmul245: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm123, permute_dims245, out_dtype="void")
            add183: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul245, model_layers_25_self_attn_c_attn_bias2)
            reshape244: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add183, R.shape([1, 1, 20, 128]))
            reshape245: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape244, R.shape([1, 20, 128]))
            lv309 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape245), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape246: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv309, R.shape([1, 1, 16, 128]))
            reshape247: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape246, R.shape([1, 1, 2048]))
            lv310 = R.call_tir(cls.dequantize2, (model_layers_25_self_attn_o_proj_q_weight2, model_layers_25_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims246: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv310, axes=None)
            matmul246: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape247, permute_dims246, out_dtype="void")
            add184: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul246, add182)
            rms_norm124: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add184, model_layers_25_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv311 = R.call_tir(cls.dequantize3, (model_layers_25_mlp_gate_up_proj_q_weight2, model_layers_25_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims247: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv311, axes=None)
            matmul247: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm124, permute_dims247, out_dtype="void")
            split61: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul247, indices_or_sections=2, axis=-1)
            split_061: R.Tensor((1, 1, 11008), dtype="float16") = split61[0]
            split_161: R.Tensor((1, 1, 11008), dtype="float16") = split61[1]
            silu61: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_061)
            mul61: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu61, split_161)
            lv312 = R.call_tir(cls.dequantize4, (model_layers_25_mlp_down_proj_q_weight2, model_layers_25_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims248: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv312, axes=None)
            matmul248: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul61, permute_dims248, out_dtype="void")
            add185: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul248, add184)
            rms_norm125: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add185, model_layers_26_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv313 = R.call_tir(cls.dequantize1, (model_layers_26_self_attn_c_attn_q_weight2, model_layers_26_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims249: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv313, axes=None)
            matmul249: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm125, permute_dims249, out_dtype="void")
            add186: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul249, model_layers_26_self_attn_c_attn_bias2)
            reshape248: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add186, R.shape([1, 1, 20, 128]))
            reshape249: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape248, R.shape([1, 20, 128]))
            lv314 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape249), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape250: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv314, R.shape([1, 1, 16, 128]))
            reshape251: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape250, R.shape([1, 1, 2048]))
            lv315 = R.call_tir(cls.dequantize2, (model_layers_26_self_attn_o_proj_q_weight2, model_layers_26_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims250: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv315, axes=None)
            matmul250: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape251, permute_dims250, out_dtype="void")
            add187: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul250, add185)
            rms_norm126: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add187, model_layers_26_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv316 = R.call_tir(cls.dequantize3, (model_layers_26_mlp_gate_up_proj_q_weight2, model_layers_26_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims251: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv316, axes=None)
            matmul251: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm126, permute_dims251, out_dtype="void")
            split62: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul251, indices_or_sections=2, axis=-1)
            split_062: R.Tensor((1, 1, 11008), dtype="float16") = split62[0]
            split_162: R.Tensor((1, 1, 11008), dtype="float16") = split62[1]
            silu62: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_062)
            mul62: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu62, split_162)
            lv317 = R.call_tir(cls.dequantize4, (model_layers_26_mlp_down_proj_q_weight2, model_layers_26_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims252: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv317, axes=None)
            matmul252: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul62, permute_dims252, out_dtype="void")
            add188: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul252, add187)
            rms_norm127: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add188, model_layers_27_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv318 = R.call_tir(cls.dequantize1, (model_layers_27_self_attn_c_attn_q_weight2, model_layers_27_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims253: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv318, axes=None)
            matmul253: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm127, permute_dims253, out_dtype="void")
            add189: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul253, model_layers_27_self_attn_c_attn_bias2)
            reshape252: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add189, R.shape([1, 1, 20, 128]))
            reshape253: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape252, R.shape([1, 20, 128]))
            lv319 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape253), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape254: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv319, R.shape([1, 1, 16, 128]))
            reshape255: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape254, R.shape([1, 1, 2048]))
            lv320 = R.call_tir(cls.dequantize2, (model_layers_27_self_attn_o_proj_q_weight2, model_layers_27_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims254: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv320, axes=None)
            matmul254: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape255, permute_dims254, out_dtype="void")
            add190: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul254, add188)
            rms_norm128: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add190, model_layers_27_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv321 = R.call_tir(cls.dequantize3, (model_layers_27_mlp_gate_up_proj_q_weight2, model_layers_27_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims255: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv321, axes=None)
            matmul255: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm128, permute_dims255, out_dtype="void")
            split63: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul255, indices_or_sections=2, axis=-1)
            split_063: R.Tensor((1, 1, 11008), dtype="float16") = split63[0]
            split_163: R.Tensor((1, 1, 11008), dtype="float16") = split63[1]
            silu63: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_063)
            mul63: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu63, split_163)
            lv322 = R.call_tir(cls.dequantize4, (model_layers_27_mlp_down_proj_q_weight2, model_layers_27_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims256: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv322, axes=None)
            matmul256: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul63, permute_dims256, out_dtype="void")
            add191: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul256, add190)
            rms_norm129: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add191, model_layers_28_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv323 = R.call_tir(cls.dequantize1, (model_layers_28_self_attn_c_attn_q_weight2, model_layers_28_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims257: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv323, axes=None)
            matmul257: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm129, permute_dims257, out_dtype="void")
            add192: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul257, model_layers_28_self_attn_c_attn_bias2)
            reshape256: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add192, R.shape([1, 1, 20, 128]))
            reshape257: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape256, R.shape([1, 20, 128]))
            lv324 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape257), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape258: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv324, R.shape([1, 1, 16, 128]))
            reshape259: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape258, R.shape([1, 1, 2048]))
            lv325 = R.call_tir(cls.dequantize2, (model_layers_28_self_attn_o_proj_q_weight2, model_layers_28_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims258: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv325, axes=None)
            matmul258: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape259, permute_dims258, out_dtype="void")
            add193: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul258, add191)
            rms_norm130: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add193, model_layers_28_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv326 = R.call_tir(cls.dequantize3, (model_layers_28_mlp_gate_up_proj_q_weight2, model_layers_28_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims259: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv326, axes=None)
            matmul259: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm130, permute_dims259, out_dtype="void")
            split64: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul259, indices_or_sections=2, axis=-1)
            split_064: R.Tensor((1, 1, 11008), dtype="float16") = split64[0]
            split_164: R.Tensor((1, 1, 11008), dtype="float16") = split64[1]
            silu64: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_064)
            mul64: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu64, split_164)
            lv327 = R.call_tir(cls.dequantize4, (model_layers_28_mlp_down_proj_q_weight2, model_layers_28_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims260: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv327, axes=None)
            matmul260: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul64, permute_dims260, out_dtype="void")
            add194: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul260, add193)
            rms_norm131: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add194, model_layers_29_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv328 = R.call_tir(cls.dequantize1, (model_layers_29_self_attn_c_attn_q_weight2, model_layers_29_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims261: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv328, axes=None)
            matmul261: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm131, permute_dims261, out_dtype="void")
            add195: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul261, model_layers_29_self_attn_c_attn_bias2)
            reshape260: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add195, R.shape([1, 1, 20, 128]))
            reshape261: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape260, R.shape([1, 20, 128]))
            lv329 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape261), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape262: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv329, R.shape([1, 1, 16, 128]))
            reshape263: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape262, R.shape([1, 1, 2048]))
            lv330 = R.call_tir(cls.dequantize2, (model_layers_29_self_attn_o_proj_q_weight2, model_layers_29_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims262: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv330, axes=None)
            matmul262: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape263, permute_dims262, out_dtype="void")
            add196: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul262, add194)
            rms_norm132: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add196, model_layers_29_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv331 = R.call_tir(cls.dequantize3, (model_layers_29_mlp_gate_up_proj_q_weight2, model_layers_29_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims263: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv331, axes=None)
            matmul263: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm132, permute_dims263, out_dtype="void")
            split65: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul263, indices_or_sections=2, axis=-1)
            split_065: R.Tensor((1, 1, 11008), dtype="float16") = split65[0]
            split_165: R.Tensor((1, 1, 11008), dtype="float16") = split65[1]
            silu65: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_065)
            mul65: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu65, split_165)
            lv332 = R.call_tir(cls.dequantize4, (model_layers_29_mlp_down_proj_q_weight2, model_layers_29_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims264: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv332, axes=None)
            matmul264: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul65, permute_dims264, out_dtype="void")
            add197: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul264, add196)
            rms_norm133: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add197, model_layers_30_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv333 = R.call_tir(cls.dequantize1, (model_layers_30_self_attn_c_attn_q_weight2, model_layers_30_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims265: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv333, axes=None)
            matmul265: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm133, permute_dims265, out_dtype="void")
            add198: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul265, model_layers_30_self_attn_c_attn_bias2)
            reshape264: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add198, R.shape([1, 1, 20, 128]))
            reshape265: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape264, R.shape([1, 20, 128]))
            lv334 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape265), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape266: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv334, R.shape([1, 1, 16, 128]))
            reshape267: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape266, R.shape([1, 1, 2048]))
            lv335 = R.call_tir(cls.dequantize2, (model_layers_30_self_attn_o_proj_q_weight2, model_layers_30_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims266: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv335, axes=None)
            matmul266: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape267, permute_dims266, out_dtype="void")
            add199: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul266, add197)
            rms_norm134: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add199, model_layers_30_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv336 = R.call_tir(cls.dequantize3, (model_layers_30_mlp_gate_up_proj_q_weight2, model_layers_30_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims267: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv336, axes=None)
            matmul267: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm134, permute_dims267, out_dtype="void")
            split66: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul267, indices_or_sections=2, axis=-1)
            split_066: R.Tensor((1, 1, 11008), dtype="float16") = split66[0]
            split_166: R.Tensor((1, 1, 11008), dtype="float16") = split66[1]
            silu66: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_066)
            mul66: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu66, split_166)
            lv337 = R.call_tir(cls.dequantize4, (model_layers_30_mlp_down_proj_q_weight2, model_layers_30_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims268: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv337, axes=None)
            matmul268: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul66, permute_dims268, out_dtype="void")
            add200: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul268, add199)
            rms_norm135: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add200, model_layers_31_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv338 = R.call_tir(cls.dequantize1, (model_layers_31_self_attn_c_attn_q_weight2, model_layers_31_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims269: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv338, axes=None)
            matmul269: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm135, permute_dims269, out_dtype="void")
            add201: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul269, model_layers_31_self_attn_c_attn_bias2)
            reshape268: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add201, R.shape([1, 1, 20, 128]))
            reshape269: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape268, R.shape([1, 20, 128]))
            lv339 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape269), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape270: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv339, R.shape([1, 1, 16, 128]))
            reshape271: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape270, R.shape([1, 1, 2048]))
            lv340 = R.call_tir(cls.dequantize2, (model_layers_31_self_attn_o_proj_q_weight2, model_layers_31_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims270: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv340, axes=None)
            matmul270: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape271, permute_dims270, out_dtype="void")
            add202: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul270, add200)
            rms_norm136: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add202, model_layers_31_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv341 = R.call_tir(cls.dequantize3, (model_layers_31_mlp_gate_up_proj_q_weight2, model_layers_31_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims271: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv341, axes=None)
            matmul271: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm136, permute_dims271, out_dtype="void")
            split67: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul271, indices_or_sections=2, axis=-1)
            split_067: R.Tensor((1, 1, 11008), dtype="float16") = split67[0]
            split_167: R.Tensor((1, 1, 11008), dtype="float16") = split67[1]
            silu67: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_067)
            mul67: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu67, split_167)
            lv342 = R.call_tir(cls.dequantize4, (model_layers_31_mlp_down_proj_q_weight2, model_layers_31_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims272: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv342, axes=None)
            matmul272: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul67, permute_dims272, out_dtype="void")
            add203: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul272, add202)
            rms_norm137: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add203, model_layers_32_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv343 = R.call_tir(cls.dequantize1, (model_layers_32_self_attn_c_attn_q_weight2, model_layers_32_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims273: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv343, axes=None)
            matmul273: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm137, permute_dims273, out_dtype="void")
            add204: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul273, model_layers_32_self_attn_c_attn_bias2)
            reshape272: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add204, R.shape([1, 1, 20, 128]))
            reshape273: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape272, R.shape([1, 20, 128]))
            lv344 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape273), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape274: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv344, R.shape([1, 1, 16, 128]))
            reshape275: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape274, R.shape([1, 1, 2048]))
            lv345 = R.call_tir(cls.dequantize2, (model_layers_32_self_attn_o_proj_q_weight2, model_layers_32_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims274: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv345, axes=None)
            matmul274: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape275, permute_dims274, out_dtype="void")
            add205: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul274, add203)
            rms_norm138: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add205, model_layers_32_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv346 = R.call_tir(cls.dequantize3, (model_layers_32_mlp_gate_up_proj_q_weight2, model_layers_32_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims275: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv346, axes=None)
            matmul275: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm138, permute_dims275, out_dtype="void")
            split68: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul275, indices_or_sections=2, axis=-1)
            split_068: R.Tensor((1, 1, 11008), dtype="float16") = split68[0]
            split_168: R.Tensor((1, 1, 11008), dtype="float16") = split68[1]
            silu68: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_068)
            mul68: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu68, split_168)
            lv347 = R.call_tir(cls.dequantize4, (model_layers_32_mlp_down_proj_q_weight2, model_layers_32_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims276: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv347, axes=None)
            matmul276: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul68, permute_dims276, out_dtype="void")
            add206: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul276, add205)
            rms_norm139: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add206, model_layers_33_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv348 = R.call_tir(cls.dequantize1, (model_layers_33_self_attn_c_attn_q_weight2, model_layers_33_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims277: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv348, axes=None)
            matmul277: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm139, permute_dims277, out_dtype="void")
            add207: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul277, model_layers_33_self_attn_c_attn_bias2)
            reshape276: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add207, R.shape([1, 1, 20, 128]))
            reshape277: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape276, R.shape([1, 20, 128]))
            lv349 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape277), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape278: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv349, R.shape([1, 1, 16, 128]))
            reshape279: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape278, R.shape([1, 1, 2048]))
            lv350 = R.call_tir(cls.dequantize2, (model_layers_33_self_attn_o_proj_q_weight2, model_layers_33_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims278: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv350, axes=None)
            matmul278: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape279, permute_dims278, out_dtype="void")
            add208: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul278, add206)
            rms_norm140: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add208, model_layers_33_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv351 = R.call_tir(cls.dequantize3, (model_layers_33_mlp_gate_up_proj_q_weight2, model_layers_33_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims279: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv351, axes=None)
            matmul279: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm140, permute_dims279, out_dtype="void")
            split69: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul279, indices_or_sections=2, axis=-1)
            split_069: R.Tensor((1, 1, 11008), dtype="float16") = split69[0]
            split_169: R.Tensor((1, 1, 11008), dtype="float16") = split69[1]
            silu69: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_069)
            mul69: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu69, split_169)
            lv352 = R.call_tir(cls.dequantize4, (model_layers_33_mlp_down_proj_q_weight2, model_layers_33_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims280: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv352, axes=None)
            matmul280: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul69, permute_dims280, out_dtype="void")
            add209: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul280, add208)
            rms_norm141: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add209, model_layers_34_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv353 = R.call_tir(cls.dequantize1, (model_layers_34_self_attn_c_attn_q_weight2, model_layers_34_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims281: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv353, axes=None)
            matmul281: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm141, permute_dims281, out_dtype="void")
            add210: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul281, model_layers_34_self_attn_c_attn_bias2)
            reshape280: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add210, R.shape([1, 1, 20, 128]))
            reshape281: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape280, R.shape([1, 20, 128]))
            lv354 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape281), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape282: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv354, R.shape([1, 1, 16, 128]))
            reshape283: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape282, R.shape([1, 1, 2048]))
            lv355 = R.call_tir(cls.dequantize2, (model_layers_34_self_attn_o_proj_q_weight2, model_layers_34_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims282: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv355, axes=None)
            matmul282: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape283, permute_dims282, out_dtype="void")
            add211: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul282, add209)
            rms_norm142: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add211, model_layers_34_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv356 = R.call_tir(cls.dequantize3, (model_layers_34_mlp_gate_up_proj_q_weight2, model_layers_34_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims283: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv356, axes=None)
            matmul283: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm142, permute_dims283, out_dtype="void")
            split70: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul283, indices_or_sections=2, axis=-1)
            split_070: R.Tensor((1, 1, 11008), dtype="float16") = split70[0]
            split_170: R.Tensor((1, 1, 11008), dtype="float16") = split70[1]
            silu70: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_070)
            mul70: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu70, split_170)
            lv357 = R.call_tir(cls.dequantize4, (model_layers_34_mlp_down_proj_q_weight2, model_layers_34_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims284: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv357, axes=None)
            matmul284: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul70, permute_dims284, out_dtype="void")
            add212: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul284, add211)
            rms_norm143: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add212, model_layers_35_input_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv358 = R.call_tir(cls.dequantize1, (model_layers_35_self_attn_c_attn_q_weight2, model_layers_35_self_attn_c_attn_q_scale2), out_sinfo=R.Tensor((2560, 2048), dtype="float16"))
            permute_dims285: R.Tensor((2048, 2560), dtype="float16") = R.permute_dims(lv358, axes=None)
            matmul285: R.Tensor((1, 1, 2560), dtype="float16") = R.matmul(rms_norm143, permute_dims285, out_dtype="void")
            add213: R.Tensor((1, 1, 2560), dtype="float16") = R.add(matmul285, model_layers_35_self_attn_c_attn_bias2)
            reshape284: R.Tensor((1, 1, 20, 128), dtype="float16") = R.reshape(add213, R.shape([1, 1, 20, 128]))
            reshape285: R.Tensor((1, 20, 128), dtype="float16") = R.reshape(reshape284, R.shape([1, 20, 128]))
            lv359 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape285), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            reshape286: R.Tensor((1, 1, 16, 128), dtype="float16") = R.reshape(lv359, R.shape([1, 1, 16, 128]))
            reshape287: R.Tensor((1, 1, 2048), dtype="float16") = R.reshape(reshape286, R.shape([1, 1, 2048]))
            lv360 = R.call_tir(cls.dequantize2, (model_layers_35_self_attn_o_proj_q_weight2, model_layers_35_self_attn_o_proj_q_scale2), out_sinfo=R.Tensor((2048, 2048), dtype="float16"))
            permute_dims286: R.Tensor((2048, 2048), dtype="float16") = R.permute_dims(lv360, axes=None)
            matmul286: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(reshape287, permute_dims286, out_dtype="void")
            add214: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul286, add212)
            rms_norm144: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add214, model_layers_35_post_attention_layernorm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv361 = R.call_tir(cls.dequantize3, (model_layers_35_mlp_gate_up_proj_q_weight2, model_layers_35_mlp_gate_up_proj_q_scale2), out_sinfo=R.Tensor((22016, 2048), dtype="float16"))
            permute_dims287: R.Tensor((2048, 22016), dtype="float16") = R.permute_dims(lv361, axes=None)
            matmul287: R.Tensor((1, 1, 22016), dtype="float16") = R.matmul(rms_norm144, permute_dims287, out_dtype="void")
            split71: R.Tuple(R.Tensor((1, 1, 11008), dtype="float16"), R.Tensor((1, 1, 11008), dtype="float16")) = R.split(matmul287, indices_or_sections=2, axis=-1)
            split_071: R.Tensor((1, 1, 11008), dtype="float16") = split71[0]
            split_171: R.Tensor((1, 1, 11008), dtype="float16") = split71[1]
            silu71: R.Tensor((1, 1, 11008), dtype="float16") = R.nn.silu(split_071)
            mul71: R.Tensor((1, 1, 11008), dtype="float16") = R.multiply(silu71, split_171)
            lv362 = R.call_tir(cls.dequantize4, (model_layers_35_mlp_down_proj_q_weight2, model_layers_35_mlp_down_proj_q_scale2), out_sinfo=R.Tensor((2048, 11008), dtype="float16"))
            permute_dims288: R.Tensor((11008, 2048), dtype="float16") = R.permute_dims(lv362, axes=None)
            matmul288: R.Tensor((1, 1, 2048), dtype="float16") = R.matmul(mul71, permute_dims288, out_dtype="void")
            add215: R.Tensor((1, 1, 2048), dtype="float16") = R.add(matmul288, add214)
            rms_norm145: R.Tensor((1, 1, 2048), dtype="float16") = R.nn.rms_norm(add215, model_norm_weight2, axes=[-1], epsilon=9.9999999999999995e-07)
            lv363 = R.call_tir(cls.dequantize, (model_embed_tokens_q_weight2, model_embed_tokens_q_scale2), out_sinfo=R.Tensor((151936, 2048), dtype="float16"))
            permute_dims289: R.Tensor((2048, 151936), dtype="float16") = R.permute_dims(lv363, axes=None)
            matmul289: R.Tensor((1, 1, 151936), dtype="float32") = R.matmul(rms_norm145, permute_dims289, out_dtype="float32")
            gv2: R.Tuple(R.Tensor((1, 1, 151936), dtype="float32"), R.Object) = matmul289, paged_kv_cache
            R.output(gv2)
        return gv2

    @R.function
    def embed(input_ids: R.Tensor(("seq_len",), dtype="int32"), packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,),