# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        # Mask logits in place with a packed bitmask.
        # For every selected sequence vs in [0, num_seq) and every vocabulary index
        # vv in [0, vocab_size), with row = seq_ids[vs]:
        #   keep logits[row, vv]     if bit (vv % 32) of bitmask[row, vv // 32] is 1
        #   set it to -3.4028234663852886e38 (most negative finite float32) otherwise
        #
        # Buffers (shapes from the match_buffer calls below):
        #   logits  : (batch_size, vocab_size), default dtype (float32); updated in place.
        #   seq_ids : (num_seq,) int32; chooses which rows of logits/bitmask are processed.
        #   bitmask : (batch_size, ceil(vocab_size / 32)) int32; 32 vocabulary bits per word.
        # Execution: CUDA target (see func_attr); already scheduled ("tir.is_scheduled"),
        # one thread per (vs, vv) pair, 1024 threads per block on blockIdx.x/threadIdx.x.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        # Flattened (vs, vv) iteration space of num_seq * vocab_size elements,
        # rounded up to whole 1024-thread blocks.
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    # Un-flatten the global thread index into (sequence slot, vocab index).
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    # Skip the padding threads in the last block.
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    # Shift then "& 1" isolates a single bit, so sign extension of the
                    # int32 word (when bit 31 is set) does not affect the test.
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0))

    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        # Add a per-entry bias to logits in place:
        #   logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp]   for vp in [0, num_token)
        #
        # Buffers (shapes from the match_buffer calls below):
        #   logits     : (batch_size, vocab_size), default dtype (float32); updated in place.
        #   pos2seq_id : (num_token,) int32; used directly as the logits row index.
        #   token_ids  : (num_token,) int32; logits column index.
        #   logit_bias : (num_token,) default dtype (float32); value added.
        # Execution: CUDA target (see func_attr); one thread per vp, 1024 threads per block.
        #
        # NOTE(review): each thread does a plain (non-atomic) read-modify-write. If two
        # vp map to the same (row, token) pair the updates race. Presumably the caller
        # guarantees unique pairs; confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        # One thread per entry; the grid is rounded up to whole 1024-thread blocks.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Skip the padding threads in the last block.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        # Apply per-sequence penalties to selected logits, in place.
        # For each vp in [0, num_token), let s = pos2seq_id[vp], row = seq_ids[s],
        # tok = token_ids[vp]:
        #   1) logits[row, tok] -= penalties[s, 0] + float32(token_cnt[vp]) * penalties[s, 1]
        #   2) logits[row, tok]  = logits[row, tok] * penalties[s, 2]  if logits[row, tok] < 0
        #                          logits[row, tok] / penalties[s, 2]  otherwise
        # Step 2 tests the sign of the value already updated by step 1.
        #
        # Buffers (shapes from the match_buffer calls below):
        #   logits     : (batch_size, vocab_size), default dtype (float32); updated in place.
        #   seq_ids    : (num_seq,) int32; maps a sequence slot s to its logits row.
        #   pos2seq_id : (num_token,) int32; sequence slot s for entry vp.
        #   token_ids  : (num_token,) int32; logits column index.
        #   token_cnt  : (num_token,) int32; multiplier for penalties[s, 1].
        #   penalties  : (num_seq, 3), default dtype (float32).
        # The penalty columns presumably correspond to (presence, frequency, repetition)
        # penalties, inferred from the arithmetic above; TODO confirm with the caller.
        # Execution: CUDA target (see func_attr); one thread per vp, 1024 threads per block.
        #
        # NOTE(review): both writes are plain read-modify-writes. Duplicate (row, tok)
        # pairs across vp would race. Presumably the caller guarantees uniqueness;
        # confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        # One thread per entry; the grid is rounded up to whole 1024-thread blocks.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Skip the padding threads in the last block.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    # Step 1: subtract the flat penalty plus the count-scaled penalty.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    # Step 2: scale by penalties[s, 2]. Multiply negatives and divide
                    # non-negatives, so a factor > 1 pushes the logit down either way.
                    # A zero in column 2 would yield inf/nan for non-negative logits.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    @T.prim_func(private=True)
    def argsort(var_probs: T.handle, var_argsort_gpu_v1: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size), offset_factor=1)
        out_buf = T.match_buffer(var_argsort_gpu_v1, (batch_size, vocab_size), "int32", align=8)
        # with T.block("root"):
        value_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        value_swap_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        out_swap_buf = T.alloc_buffer((batch_size, vocab_size), "int32", align=8)
        with T.block("argsort_gpu"):
            T.reads()
            T.writes()
            if vocab_size > T.int64(0):
                with T.launch_thread("threadIdx.x", T.int64(1024)) as threadIdx_x:
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(1023)) // T.int64(1024)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    if blockIdx_x * T.int64(1024) + threadIdx_x < vocab_size:
                        value_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = probs[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size]
                        out_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = T.Cast("int32", blockIdx_x * T.int64(1024) + threadIdx_x)
                with T.attr(0, "hand_threaded", 0):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.int64(64))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(127)) // T.int64(128)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    temp_keys_swap = T.allocate([T.int64(128)], "float32", "shared")
                    temp_values_swap = T.allocate([T.int64(128)], "int32", "shared")
                    temp_keys = T.allocate([T.int64(1)], "float32", "local")
                    temp_values = T.allocate([T.int64(1)], "int32", "local")
                    temp_cond1 = T.allocate([T.int64(1)], "float32", "local")
                    temp_cond2 = T.allocate([T.int64(1)], "float32", "local")
                    temp_keys_swap_1 = T.Buffer((128,), data=temp_keys_swap, scope="shared")
                    temp_values_swap_1 = T.Buffer((128,), "int32", data=temp_values_swap, scope="shared")
                    for i in range(T.int64(2)):
                        if T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128) < vocab_size:
                            temp_keys_swap_1[T.int64(2) * threadIdx_x + i] = value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) % vocab_size]
                            temp_values_swap_1[T.int64(2) * threadIdx_x + i] = out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) % vocab_size]
                    T.tvm_storage_sync("shared")
                    for j in range(T.min(T.int64(128), vocab_size - blockIdx_x * T.int64(128))):
                        if T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) < T.min(T.int64(128), vocab_size - blockIdx_x * T.int64(128)) - T.int64(1):
                            temp_cond1_1 = T.Buffer((1,), data=temp_cond1, scope="local")
                            temp_cond1_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                            temp_cond2_1 = T.Buffer((1,), data=temp_cond2, scope="local")
                            temp_cond2_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                            if temp_cond1_1[T.int64(0)] < temp_cond2_1[T.int64(0)]:
                                temp_keys_1 = T.Buffer((1,), data=temp_keys, scope="local")
                                temp_keys_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                                temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                                temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)] = temp_keys_1[T.int64(0)]
                                temp_values_1 = T.Buffer((1,), "int32", data=temp_values, scope="local")
                                temp_values_1[T.int64(0)] = temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                                temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)] = temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                                temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)] = temp_values_1[T.int64(0)]
                        T.tvm_storage_sync("shared")
                    for k in range(T.int64(2)):
                        if T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128) < vocab_size:
                            value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_keys_swap_1[T.int64(2) * threadIdx_x + k]
                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_keys_swap_1[T.int64(2) * threadIdx_x + k]
                            out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_values_swap_1[T.int64(2) * threadIdx_x + k]
                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_values_swap_1[T.int64(2) * threadIdx_x + k]
                for i_0 in range(T.Cast("int64", T.ceil(T.log2(T.Cast("float64", vocab_size)))) - T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.max(T.int64(1), T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + T.int64(4095)) // T.int64(4096)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size * ((vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) - T.int64(1))) // T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))))
                    if T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) < vocab_size:
                        if (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + T.int64(4095)) // T.int64(4096) == T.int64(1):
                            if i_0 % T.int64(2) == T.int64(0):
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                i = T.allocate([T.int64(1)], "int64", "local")
                                j = T.allocate([T.int64(1)], "int64", "local")
                                i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]
                                j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - last_1[T.int64(0)]
                                for i_1_1 in range(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)) - threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))))):
                                    if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)):
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                    else:
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)):
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                            else:
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                i = T.allocate([T.int64(1)], "int64", "local")
                                j = T.allocate([T.int64(1)], "int64", "local")
                                i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]
                                j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - last_1[T.int64(0)]
                                for i_2 in range(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)) - threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))))):
                                    if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)):
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                    else:
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)):
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                        else:
                            if i_0 % T.int64(2) == T.int64(0):
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), blockIdx_x * T.int64(4096) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(blockIdx_x * T.int64(4096), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                if i_0 % T.int64(2) == T.int64(0):
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_3 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                else:
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_4 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                            else:
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), blockIdx_x * T.int64(4096) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(blockIdx_x * T.int64(4096), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                if i_0 % T.int64(2) == T.int64(0):
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_5 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                else:
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_6 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                if T.Cast("int64", T.ceil(T.log2(T.Cast("float64", vocab_size)))) > T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))) and (T.Cast("int64", T.ceil(T.log2(T.Cast("float64", vocab_size)))) - T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) % T.int64(2) == T.int64(1):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.int64(1024))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(1023)) // T.int64(1024)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    if blockIdx_x * T.int64(1024) + threadIdx_x < vocab_size:
                        value_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size] = value_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size]
                        out_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size] = out_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size]

    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Attention for one query row per batch entry against K/V rows gathered from
        # a paged cache, with optional inline rotary position embedding of Q and K.
        #
        # Buffers (shapes from the match_buffer calls below):
        #   Q        (B, 16, 128) float16 -- 16 query heads x head dim 128.
        #   pages    (max_num_pages, 2, 2, 16, 128) float16
        #            -- [page, 0 = K / 1 = V, head slot `by`, row within page (0..15), dim].
        #   page_table_indptr (B + 1,) / page_table_values (nnz_pages,)
        #            -- sequence b owns pages page_table_values[indptr[b]:indptr[b + 1]].
        #   length_info (B,) -- number of valid rows in the last page of each sequence;
        #            KV length = (num_pages - 1) * 16 + length_info[b], or 0 with no pages.
        #   k_rope_pos_offset (B,) -- KV row r is rotated at position offset + r.
        #   q_rope_position   (B,) -- position used to rotate the query.
        #   output (B, 16, 128) float16 and lse (B, 16) float32 -- softmax-weighted sum of
        #            V and the base-2 log-sum-exp of the (scaled) scores per query head.
        #   rotary_mode == 1 rotates Q and K (never V); any other value uses them as-is.
        #   _0 is not referenced anywhere in this body.
        #
        # Thread mapping: blockIdx.x = batch entry; blockIdx.y (extent 2) = head slot
        # `by` in `pages` (bz = fused_by_bz // 2 is always 0); threadIdx.y (8) selects the
        # query head by * 8 + ty, so 8 query heads share each `by` slot (looks like
        # grouped-query attention); threadIdx.x (32) owns head dims tx * 4 .. tx * 4 + 3;
        # threadIdx.z (2) splits every 16-row KV tile into two halves of 8 rows.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # sm_scale = log2(e) / sqrt(128). Scores are pre-multiplied by log2(e), so the
        # softmax below uses exp2 / log2 instead of exp / log.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                # NOTE(review): these access annotations are not exhaustive;
                                # pages, page_table_values and k_rope_pos_offset are also read.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                # One 16-row KV tile (2 tz halves x 8 rows) staged in shared memory.
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                # Per-tz partial results exchanged for the final merge.
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                # Online-softmax state: running max, running denominator and
                                # the unnormalized output accumulator for this thread's 4 dims.
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Total KV rows: all pages but the last are full (16 rows each);
                                # the last page contributes length_info[b] rows.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                # -50000 is used as a finite stand-in for -inf.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query dims; when rotary_mode == 1 apply
                                # RoPE: out[d] = cos(f) * x[d] + sin(f) * (d < 64 ? -x[d + 64] : x[d - 64]),
                                # with f = pos * rope_scale / rope_theta ** ((2 * d % 128) / 128).
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Walk the KV sequence in tiles of 16 rows. Threads with a given tz
                                # own global rows (iterator * 2 + tz) * 8 + 0..7 of each tile.
                                # NOTE(review): there is no barrier at the end of this loop body,
                                # yet the next iteration's KV_load overwrites K_smem/V_smem rows that
                                # threads with other ty values may still be reading (row tz * 8 + j
                                # is loaded by the thread with ty == j). Presumably TVM's shared-
                                # memory ThreadSync lowering pass inserts the missing barrier;
                                # confirm before lowering this IR through any other path.
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            # Each thread stages one row (4 dims per tx) of K and V.
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                seq_offset: T.int32 = row_g
                                                # Translate the logical row into (physical page, row in page).
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    # K gets the same RoPE as Q, at position k_rope_pos_offset[b] + row_g.
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Rows past the sequence end are zero-filled, so they add
                                                # nothing to O even though their weight is not exactly 0.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # Make the whole tile visible before any thread reads it.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Scores for this tz half's 8 rows.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        # Sum the 4-dim partial dot products over the 32 tx lanes
                                        # into the full 128-dim score t0.
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Mask rows beyond the sequence end with the -50000 sentinel.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous denominator and accumulator to the new max.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    # Turn scores into unnormalized weights exp2(s - m) and accumulate.
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this tz half's partial (max, denominator, unnormalized O)
                                # to shared memory so both halves can be merged.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Merge both tz partials starting from the sentinel state. Threads
                                # with tz = 0 and tz = 1 compute identical merged values.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize. Both tz threads store the same values to the same
                                # addresses. When kv_chunk_len is 0, the sentinel initialization
                                # gives an all-zero output row and lse = -50000 + log2(3), not -inf.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                # Base-2 log-sum-exp of the log2(e)-scaled scores.
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Decode-step attention: one query vector per (batch entry, head) attends
        # over that entry's paged KV cache, using an online (running max / running
        # sum) softmax in base-2 units. Cached positions are remapped through
        # length_info (see kv_chunk_len / seq_offset below) so that a span of
        # positions is skipped. The name suggests sliding-window attention; the
        # meaning of the length_info rows is inferred from their use here
        # (TODO confirm against the host-side caller).
        #
        # Launch layout (from the thread bindings below):
        #   blockIdx.x  -> bx: batch entry.
        #   blockIdx.y  -> fused_by_bz: by = fused_by_bz % 2 selects the KV head
        #                  (pages axis 2 has extent 2); bz = fused_by_bz // 2 is
        #                  always 0 for this extent.
        #   threadIdx.y -> ty: query head within the group of 8 sharing KV head
        #                  `by` (query head index = by * 8 + bz * 8 + ty).
        #   threadIdx.x -> tx: owns head-dim elements tx*4 .. tx*4+3 (32 * 4 = 128).
        #   threadIdx.z -> tz: which 8-row half of each 16-row KV tile this
        #                  thread scores; the two halves are merged at the end.
        #
        # Arguments (as used in this body):
        #   _0: not referenced in this body.
        #   Q (B, 16, 128) fp16: query vector per batch entry and head.
        #   pages (max_num_pages, 2, 2, 16, 128) fp16:
        #       - axis 1 selects K (0) or V (1);
        #       - axis 2 selects the KV head;
        #       - axis 3 selects the slot within a 16-entry page.
        #   page_table_indptr (B + 1,), page_table_values (nnz_pages,):
        #       CSR-style page list; entry bx owns
        #       page_table_values[page_table_indptr[bx]:page_table_indptr[bx + 1]].
        #   length_info (3, B):
        #       - row 0 is added to the length of the preceding full pages
        #         (presumably the fill of the last page);
        #       - rows 1 and 2 drive the position remap (see seq_offset).
        #   k_rope_pos_offset (B,), q_rope_position (B,):
        #       RoPE base position for keys / position of the query.
        #       Only used when rotary_mode == 1.
        #   output (B, 16, 128) fp16: attention output.
        #   lse (B, 16) fp32: base-2 log-sum-exp of the scaled scores.
        #   rotary_mode: 1 applies RoPE to Q and K as they are loaded; any other
        #       value loads them unchanged.
        #   rope_scale, rope_theta: RoPE position scale and frequency base.
        #   attn_score_scaling_factor: extra multiplier applied to every QK score.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # Numerically equal to log2(e) / sqrt(128). It applies the 1/sqrt(head_dim)
        # softmax scale and converts scores to base 2, so the kernel can use
        # exp2/log2 below instead of exp/log.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(2, thread="blockIdx.y"):
                for ty in T.thread_binding(8, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(2, thread="threadIdx.z"):
                            with T.block("attn"):
                                # NOTE(review): these region annotations cover only part of
                                # what the block accesses (e.g. pages is not listed); this
                                # appears to be generated IR. Confirm before relying on them.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, (fused_by_bz + 1) // 2 * 8 + ty, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 2 * 8 + fused_by_bz // 2 * 8 + ty])
                                # scope="local": per-thread scratch.
                                # scope="shared": one 16-row K/V tile plus the per-(tz, ty)
                                # partial results used by the final cross-tz merge.
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((16, 128), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((2, 8, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((2, 8, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((8,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                by: T.int32 = fused_by_bz % 2
                                bz: T.int32 = fused_by_bz // 2
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Number of KV rows this entry attends to after the remap:
                                #   (full pages) * 16 + length_info[0]
                                #     - length_info[1] + length_info[2]
                                # It is 0 when the entry owns no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                # Online-softmax state:
                                #   st_m: running max (-50000 acts as minus infinity).
                                #   st_d: running denominator.
                                #   O_local: unnormalized output for this thread's 4 dims.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements, optionally applying
                                # RoPE in rotate-half form:
                                #   out[d] = cos(f) * x[d] + sin(f) * r[d]
                                #   r[d]   = -x[d + 64] if d < 64 else x[d - 64]
                                #   f      = pos * rope_scale / rope_theta ** ((2d % 128) / 128)
                                # freq is a Let-bound variable whose value is given in where=.
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by * 8 + bz * 8 + ty, tx * 4 + vec])
                                # Walk the KV sequence in tiles of 16 rows (2 tz halves x 8 ty rows).
                                for iterator in range((kv_chunk_len[0] + 15) // 16):
                                    # Warp (tz, ty) fills shared row tile_start_s with
                                    # (compacted) KV row tile_start_g of this entry.
                                    tile_start_s: T.int32 = tz * 8 + ty
                                    tile_start_g: T.int32 = (iterator * 2 + tz) * 8 + ty
                                    # Extent-1 loop: each warp loads exactly one row per tile.
                                    for j in range(1):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Map a compacted row index to a position in the page list:
                                                # - rows below length_info[2] map to themselves;
                                                # - later rows skip ahead by length_info[1] - length_info[2].
                                                # Presumably [2] is the count of retained leading ("sink")
                                                # tokens and [1] the start of the retained window.
                                                # TODO confirm.
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                # Each tx lane copies 4 consecutive head-dim elements of K
                                                # (RoPE'd as for Q when rotary_mode == 1) and of V.
                                                # NOTE(review): K's RoPE position is
                                                # k_rope_pos_offset[bx] + row_g, i.e. the compacted index,
                                                # not seq_offset. Presumably intentional (positions re-based
                                                # within the kept window); confirm.
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Rows past the end of the sequence are zero-filled.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    # Make the whole tile visible before any warp reads rows it did not load.
                                    T.tvm_storage_sync("shared")
                                    m_prev[0] = st_m[0]
                                    # Score the 8 rows of this tz half:
                                    # - each lane forms a 4-element partial dot product;
                                    # - the sum-allreduce over tx leaves the full 128-element
                                    #   score in t0[0] on every lane.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 8 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Rows past kv_chunk_len keep the -50000 sentinel.
                                        # exp2 turns it into ~0 once st_m holds a real score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 2 + tz) * 8 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the running denominator and output to the new max,
                                    # then add this tile's exp2 weights.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(8):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    # Accumulate the weighted V rows of this tz half into O_local.
                                    # NOTE(review): there is no explicit barrier between these last
                                    # K_smem/V_smem reads and the next iteration's KV_load writes,
                                    # which come from other warps. Presumably a later TVM lowering
                                    # pass inserts it; confirm.
                                    for j in range(8):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 8 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this thread's partial (O, m, d) per (tz, ty) to shared memory.
                                # Every tx lane writes the same m/d values.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Merge the two tz partials for query head ty, starting from
                                # the seed (m = -50000, d = 1, O = 0).
                                # - The seed's weight exp2(-50000 - st_m) is ~0 once a real max is seen.
                                # - If kv_chunk_len is 0, this yields O = 0 and lse = -50000 + log2(3).
                                # - Both tz threads run the same merge and write identical results.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(2):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize, cast to fp16, and write the output.
                                # The log-sum-exp is stored in base 2.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by * 8 + bz * 8 + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                lse[batch_idx, by * 8 + bz * 8 + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Batched prefill attention of ragged queries against a paged KV cache,
        # with a position remapping driven by length_info (see K_load / V_load).
        #
        # Layout (from the match_buffer shapes below):
        #   q, output: (total_len, 16 query heads, 128 head dim), float16.
        #     q_indptr[b] .. q_indptr[b + 1] are sequence b's rows in q/output/lse.
        #   pages: (max_num_pages, 2, 2, 16, 128). Dim 1 selects K (0) or V (1),
        #     dim 2 is the KV head (indexed by blockIdx.y), dim 3 is the slot
        #     within a 16-entry page, dim 4 is the head dim.
        #     page_values[page_indptr[b] .. page_indptr[b + 1]] are sequence b's
        #     page numbers.
        #   lse: (total_len, 16). Receives m + log2(d), i.e. a base-2
        #     log-sum-exp of the scaled scores.
        # The 16 query heads share 2 KV heads: KV head `by` serves query heads
        # by * 8 .. by * 8 + 7.
        #
        # Work decomposition: a "row" is a (query token, query head) pair.
        # Tile row r maps to token (LH_start + r) // 8 of the sequence and to
        # query head by * 8 + (LH_start + r) % 8, so a 32-row tile covers
        # 4 tokens x 8 heads. Each of the 16 blockIdx.x blocks starts at global
        # tile bx and advances by 16 tiles per outer iteration
        # (tile_id += 16 at the bottom). The global tile index is converted to
        # (batch_idx, tile within batch) by the inner while loop.
        #
        # Scores use an online softmax in base 2:
        #   m_smem holds the running row max.
        #   d_smem holds the running sum of exp2(s - m).
        #   O_local is rescaled by exp2(m_prev - m_new) before each KV chunk
        #   of 32 positions is accumulated.
        #
        # `_0` is not referenced in the body.
        # NOTE(review): the meaning of length_info's three rows is inferred
        # from usage only. It is presumably [last page length, sliding-window
        # offset, sink size]; confirm against the host-side caller.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        # Grid: 16 blocks along x (tile walkers) x 2 along y (one per KV head).
        # Each block has 4 x 32 threads.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            # Rows of a sequence = query tokens x 8 heads per KV head.
                            # The tile count is the ceil-division of rows by 32.
                            # NOTE(review): this reads q_indptr[1] unconditionally. With
                            # batch_size == 0, q_indptr has a single element; confirm
                            # callers never pass an empty batch.
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            # tvm_thread_invariant marks these conditions as uniform across
                            # the block. The body contains tvm_storage_sync barriers that
                            # every thread must reach.
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Advance (batch_idx, tile_id) until tile_id falls inside
                                # the current sequence's tiles. batch_idx only moves forward
                                # because tile_id only grows.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    # First tile row of this tile within the sequence's
                                    # (token, head) rows.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    # KV positions visible to this sequence, computed as:
                                    #   (num_pages - 1) full pages of 16 entries
                                    #   + length_info[0]
                                    #   - length_info[1]
                                    #   + length_info[2]
                                    # The result is 0 when the sequence has no pages.
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0)
                                    # Barrier before re-initializing shared m/d: the previous
                                    # tile's lse_store still reads m_smem / d_smem.
                                    T.tvm_storage_sync("shared")
                                    # row = ty * 32 + tx (i is always 0), so row < 32 holds
                                    # exactly for the threads with threadIdx.y == 0.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            # Running max starts very negative and the denominator
                                            # starts at 1.0. The first chunk rescales that 1.0 by
                                            # exp2(-50000 - m_new), which is negligible for typical
                                            # score ranges.
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator. O_init, O_gemm and O_store all
                                    # use the same (i, j) mapping, so each thread only touches its
                                    # own 4 x 8 slice of O_local.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load this tile's 32 query rows into shared memory.
                                    # Rows past the sequence's last token are zero-filled.
                                    # With rotary_mode == 1, rotary embedding is applied at
                                    # position q_rope_position[cur_L]. Element j is paired with
                                    # j + 64 (negated) or j - 64 (rotate-half), and its frequency
                                    # exponent is (2j mod 128) / 128.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Iterate over the visible KV positions in chunks of 32.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        # Load K for this chunk. Logical position cur_L maps to a
                                        # stored slot as follows:
                                        #   positions below length_info[2] are used as-is;
                                        #   later positions are shifted by
                                        #   length_info[1] - length_info[2].
                                        # The slot is then split into a page (via page_values) and
                                        # an offset within the 16-entry page. Positions past
                                        # kv_chunk_len load zeros. With rotary_mode == 1, rotary
                                        # embedding is applied at position
                                        # k_rope_pos_offset[b_idx] + cur_L.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        # This barrier also ensures every thread finished the
                                        # previous chunk's O gemm (which reads V_smem) before
                                        # V_load overwrites V_smem.
                                        T.tvm_storage_sync("shared")
                                        # Load V for this chunk, using the same position -> page
                                        # mapping as K_load. There is no rotary step here.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q K^T * attn_score_scaling_factor * 0.12751743082459868.
                                        # The constant equals log2(e) / sqrt(128), so scores are
                                        # already in the base-2 domain consumed by exp2 / log2 below.
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish the per-thread score fragments to shared memory so
                                        # that the row-owner threads below can read whole rows.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Row bookkeeping, done by the threads with row < 32
                                        # (threadIdx.y == 0), each owning one tile row:
                                        #   - compute the new running max over this chunk's
                                        #     unmasked positions;
                                        #   - start d_new as the old denominator rescaled to that
                                        #     new max.
                                        # Mask: with causal > 0, KV index L_kv_start + j is kept iff
                                        #   L_kv_start + j < kv_chunk_len - qo_len + row_ + 1,
                                        # where qo_len = q_indptr[b + 1] - q_indptr[b] and row_ is
                                        # the query token's index within the sequence. The last
                                        # query token therefore sees every visible KV position.
                                        # Otherwise only positions < kv_chunk_len are kept.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Replace each score in the owned row by exp2(s - m_new).
                                        # Masked positions become exp2(-50000 - m_new), which is
                                        # effectively 0 for typical score ranges. Each thread only
                                        # touches its own row, so no barrier is needed between these
                                        # row-wise steps.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Add this chunk's probabilities to the denominator. Then
                                        # publish the new max, the denominator, and the previous max
                                        # to shared memory for the O rescale / update.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m_new) + P @ V. The rescale is the
                                        # reduction's init step (O_gemm_init); O_gemm_update then
                                        # accumulates the unnormalized probabilities times V.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    # Normalize by the final denominator and write float16 output.
                                    # Rows past the sequence's last query token are skipped.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Write the base-2 log-sum-exp (m + log2(d)) for each row.
                                    # Only the 32 threads with threadIdx.y == 0 pass the
                                    # T.where guard.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to this block's next global tile (16 blocks along
                                    # blockIdx.x).
                                    tile_id[0] = tile_id[0] + 16

    # Batched prefill attention over ragged (packed) Q and K/V tensors.
    #
    # Layout, as fixed by the buffer shapes and indexing below:
    #   q / output: (qo_len, 16 query heads, 128) float16; lse: (qo_len, 16).
    #   k / v:      (kv_len, 2 KV heads, 128) float16.
    #   Sequence b owns q/output/lse rows q_indptr[b]:q_indptr[b+1] and
    #   k/v rows kv_indptr[b]:kv_indptr[b+1].
    #   KV head `by` serves query heads by*8 .. by*8+7 (8 query heads per KV head).
    #
    # Work decomposition: blockIdx.y selects the KV head. The 16 blockIdx.x blocks
    # walk a flat sequence of 32-row tiles spanning all sequences, each block
    # advancing its tile_id by 16 per iteration. A "row" is one
    # (query position, query head within the group) pair: row r of a sequence maps
    # to position r // 8 and head by*8 + r % 8, so a sequence contributes
    # q_len * 8 rows.
    #
    # Per tile:
    #   1. load Q (optionally rotary-embedded);
    #   2. stream K/V in chunks of 32 positions while keeping a base-2 online
    #      softmax (running max m_smem, running denominator d_smem) and
    #      accumulating O_local;
    #   3. write output = O / d and lse = m + log2(d).
    #
    # Scalar params:
    #   causal > 0: keep KV position p for query position t only if
    #     p <= t + (kv_len_b - q_len_b), i.e. the mask is aligned to the end of
    #     the KV sequence.
    #   rotary_mode == 1: apply rotary embedding to Q and K on load, using
    #     rope_scale and rope_theta. Otherwise Q and K are copied unchanged.
    #   attn_score_scaling_factor: extra multiplier applied to every Q.K score.
    @T.prim_func
    def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        # Buffer bindings. Head counts (16 query / 2 KV) and head_dim 128 are
        # compiled in; only sequence lengths and batch size are symbolic.
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        # Launch: 16 x 2 thread blocks, each with 4 (threadIdx.y) x 32 (threadIdx.x) threads.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Scheduling state, held in 1-element "local" buffers so
                            # that it can be mutated inside the loops below.
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # NOTE(review): `iterator` is never read or written in this function.
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Shared staging tiles. K is stored transposed ([dim, pos]),
                            # so K_smem is 128 x 32 while Q_smem and V_smem are 32 x 128.
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((128, 32), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Online-softmax state, one entry per tile row:
                            # m_smem = running max, m_prev_smem = max before the latest
                            # chunk, d_smem = running denominator.
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start with tile index bx, counted within sequence 0.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            # Each block keeps taking tiles until every sequence is consumed.
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip whole sequences until tile_id falls inside sequence batch_idx.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                # Process tile tile_id of sequence b_idx. Its rows are
                                # LH_start .. LH_start + 31 within that sequence.
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    # Barrier before m_smem/d_smem are overwritten. Presumably
                                    # this lets the previous tile's lse_store reads finish first.
                                    T.tvm_storage_sync("shared")
                                    # Reset softmax state: m = -50000 (finite stand-in for -inf),
                                    # d = 1.0. The first rescale exp2(m_prev - m_new) in update1
                                    # shrinks this initial d once a larger score is seen.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the Q tile. Tile row i is query position
                                    # cur_L = q_indptr[b] + (LH_start + i) // 8 and query head
                                    # by*8 + (LH_start + i) % 8. Rows past this sequence's last
                                    # query are zero-filled.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            # Rotary (rotate-half) when rotary_mode == 1:
                                                            #   q[j]*cos(freq) + sin(freq)*(j < 64 ? -q[j+64] : q[j-64])
                                                            #   freq = q_rope_position[cur_L] * rope_scale
                                                            #          / rope_theta^((2j mod 128) / 128)
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Stream this sequence's K/V in chunks of 32 positions.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        # Load the K chunk transposed into K_smem[dim, pos]. When
                                        # rotary_mode == 1, rotary embedding is applied at position
                                        # k_rope_pos_offset[b] + pos. Positions at or past
                                        # kv_chunk_len are zero-filled.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 64:j - 64 + 129])
                                                            T.writes(K_smem[j, i])
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[L_kv_base + cur_L, by, j + 64] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[L_kv_base + cur_L, by, j])
                                                            else:
                                                                K_smem[j, i] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Load the V chunk (no rotary). Positions at or past
                                        # kv_chunk_len are zero-filled.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                V_smem[i, j] = v[L_kv_base + cur_L, by, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = (Q @ K^T) * attn_score_scaling_factor * 0.12751743...
                                        # The constant numerically equals log2(e) / sqrt(128).
                                        # Because of that factor, the exp2/log2 calls below act as
                                        # a natural-exp softmax scaled by 1/sqrt(head_dim).
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:128, 0:32])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k_1 = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish the score tile from S_local to shared S_smem so the
                                        # per-row softmax steps below can read whole rows.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # One thread per tile row (row < 32) does the following:
                                        #   1. m_new = max(m_prev, scores of unmasked columns);
                                        #   2. d_new = d * exp2(m_prev - m_new), the old
                                        #      denominator rescaled to the new max.
                                        # Column j is unmasked when KV position L_kv_start + j is
                                        # < kv_chunk_len. If causal > 0 it must instead be
                                        # < kv_chunk_len - q_len_b + row_ + 1, where row_ is the
                                        # query position within the sequence.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Replace each score with exp2(score - m_new). Masked columns
                                        # use -50000 in place of the score, which gives a near-zero
                                        # (not exactly zero) weight.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Add this chunk's row sum to the denominator and commit the
                                        # new m/d. The previous max is kept in m_prev_smem so that O
                                        # can be rescaled next.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m) + S @ V: rescale the accumulator to
                                        # the new max, then add this chunk's weighted values.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    # NOTE(review): the declared read region is empty, but
                                                                    # the body reads O_local, m_prev_smem and m_smem.
                                                                    # Presumably tolerated because the function is marked
                                                                    # tir.is_scheduled. Confirm.
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k_1 = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    # Normalize by the softmax denominator and write float16
                                    # output, but only for rows that map to real query positions
                                    # of this sequence.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                            cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Per-row log-sum-exp in base 2, lse = m + log2(d), in the
                                    # same scaled score domain as m_smem (see the S_gemm comment).
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to this block's next tile (stride = 16 blockIdx.x blocks).
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 2, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1)
        tree_size = T.int32(is_size_var=True)
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("KV_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_base + L_kv_start + i
                                                            if L_kv_start + i < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[cur_L, by, j + 64] * T.float16(-1.0), k[cur_L, by, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[cur_L, by, j])
                                                                V_smem[i, j] = v[cur_L, by, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k_1 = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    # Verification of draft tokens arranged in a token tree, one thread block of
    # 1024 threads (threadIdx.x) per batch entry b (blockIdx.x ranges over nbatch).
    #
    # The walk starts at node token_tree_parent_ptr[b], with its first child as the
    # candidate. For each candidate child:
    #   * thread 0 accepts it when
    #         model_probs[parent, tok] >= uniform_samples[child] * draft_probs[child, tok]
    #     where tok = draft_tokens[child]. The decision is broadcast to the block
    #     through the shared-scope buffer pred_shared.
    #   * accepted: parent <- child, child <- token_tree_first_child[child].
    #   * rejected: t0 = sum_k max(model_probs[parent, k] - draft_probs[child, k], 0)
    #     is reduced across the block.
    #       - If t0 < 1e-7, the child is treated as accepted (same update as above).
    #       - Otherwise model_probs[parent, :] is overwritten IN PLACE with the
    #         normalized residual max(model_probs - draft_probs, 0) / t0, and the
    #         walk moves on to token_tree_next_sibling[child].
    # The walk ends when the candidate pointer is -1 ("no node"). Thread 0 then
    # writes the node where the walk stopped into token_tree_parent_ptr[b]; that
    # buffer is also updated in place.
    @T.prim_func(private=True)
    def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        num_nodes, vocab_size = T.int32(is_size_var=True), T.int64(is_size_var=True)
        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size))
        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32")
        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size))
        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32")
        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32")
        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,))
        nbatch = T.int32(is_size_var=True)
        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32")
        # with T.block("root"):
        # Per-thread scratch (scope "local"). pred_shared is the only shared-scope
        # buffer; it carries thread 0's accept/reject decision to the whole block.
        # Every thread executes the same pointer updates, so the per-thread
        # parent_ptr / child_ptr copies stay identical across the block.
        child_ptr = T.alloc_buffer((1,), "int32", scope="local")
        parent_ptr = T.alloc_buffer((1,), "int32", scope="local")
        child_token = T.alloc_buffer((1,), "int32", scope="local")
        done = T.alloc_buffer((1,), "bool", scope="local")
        psum = T.alloc_buffer((1,), scope="local")
        t0 = T.alloc_buffer((1,), scope="local")
        model_prob_local = T.alloc_buffer((1,), scope="local")
        draft_prob_local = T.alloc_buffer((1,), scope="local")
        p_child = T.alloc_buffer((1,), scope="local")
        q_child = T.alloc_buffer((1,), scope="local")
        uniform_sample = T.alloc_buffer((1,), scope="local")
        pred_shared = T.alloc_buffer((1,), "bool", scope="shared")
        pred_local = T.alloc_buffer((1,), "bool", scope="local")
        for _bx in T.thread_binding(nbatch, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]])
                    T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b])
                    # Start at this batch entry's root node; its first child is the first candidate.
                    parent_ptr[0] = token_tree_parent_ptr[b]
                    child_ptr[0] = token_tree_first_child[parent_ptr[0]]
                    done[0] = T.bool(False)
                    # Each iteration examines one candidate child of parent_ptr[0].
                    while not done[0]:
                        # Barrier: writes from the previous iteration (including the
                        # in-place model_probs rewrite) are visible before reading.
                        T.tvm_storage_sync("shared")
                        if child_ptr[0] == -1:
                            # No more candidates: stop walking.
                            done[0] = T.bool(True)
                            # Barrier before leaving the loop.
                            T.tvm_storage_sync("shared")
                        else:
                            # Only thread 0 evaluates the acceptance test. It is written
                            # as a product (p >= u * q), so q_child == 0 needs no division.
                            if tx == 0:
                                child_token[0] = draft_tokens[child_ptr[0]]
                                p_child[0] = model_probs[parent_ptr[0], child_token[0]]
                                q_child[0] = draft_probs[child_ptr[0], child_token[0]]
                                uniform_sample[0] = uniform_samples[child_ptr[0]]
                                pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
                            # Barrier: pred_shared is published, and thread 0's reads of
                            # model_probs finish before any thread may rewrite that row below.
                            T.tvm_storage_sync("shared")
                            # Copy the broadcast decision locally so all threads take the same branch.
                            pred_local[0] = pred_shared[0]
                            if pred_local[0]:
                                # Accepted: descend into the child.
                                parent_ptr[0] = child_ptr[0]
                                child_ptr[0] = token_tree_first_child[child_ptr[0]]
                            else:
                                # Rejected: each thread sums its share (stride 1024 over the
                                # vocabulary) of the positive residual max(p - q, 0).
                                psum[0] = T.float32(0.0)
                                for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                    if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                        model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                        psum[0] = psum[0] + model_prob_local[0]
                                # Block-wide sum of the per-thread psum values into t0.
                                with T.block("block_cross_thread"):
                                    T.reads(psum[0])
                                    T.writes(t0[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx)
                                if t0[0] < T.float32(9.9999999999999995e-08):
                                    # Residual mass is (nearly) zero: take the same path as an accept.
                                    parent_ptr[0] = child_ptr[0]
                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]
                                else:
                                    # Overwrite the parent's row with the normalized residual
                                    # max(p - q, 0) / t0, then try the next sibling.
                                    for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                        if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                            model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                            model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0]
                                    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
                    # Thread 0 stores the node where the walk stopped for this batch entry.
                    if tx == 0:
                        token_tree_parent_ptr[b] = parent_ptr[0]

    # Per-chunk log-sum-exp statistics over temperature-scaled logits.
    #
    # The vocabulary axis of A (batch_size, vocab_size) is split into num_chunks
    # chunks of 4096 entries. For each (batch row v0, chunk v1):
    #   pad     : A_pad = A / temperature if temperature > 1e-5, else A unchanged.
    #             Positions at or beyond vocab_size are filled with -FLT_MAX.
    #   max     : temp_max = maximum of A_pad over the chunk.
    #   sum_exp : if temperature > 1e-5, the sum of exp(A_pad - temp_max) over
    #             in-range positions; otherwise the number of in-range positions
    #             whose value equals temp_max.
    #   log     : chunked_sum = log(temp_sum) if temperature > 1e-5, else temp_sum
    #             (the count); chunked_max = temp_max.
    # Entries of A beyond num_chunks * 4096 are never read. NOTE(review): presumably
    # the caller sizes num_chunks = ceil(vocab_size / 4096) — confirm.
    # Unlike its neighbours, this function has no "tir.is_scheduled" attribute;
    # presumably it is scheduled by a later pass — confirm.
    @T.prim_func
    def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        # with T.block("root"):
        # Intermediates: chunked (padded) logits plus per-chunk max and sum.
        A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)))
        temp_max = T.alloc_buffer((batch_size, num_chunks))
        temp_sum = T.alloc_buffer((batch_size, num_chunks))
        # Scale by temperature (skipped when temperature <= 1e-5, which avoids
        # dividing by a near-zero value) and pad the tail with -FLT_MAX.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("pad"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2])
                T.writes(A_pad[v0, v1, v2])
                A_pad[v0, v1, v2] = T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0))
        # Per-chunk maximum (reduction over v2, initialised to -FLT_MAX).
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("max"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(A_pad[v0, v1, v2])
                T.writes(temp_max[v0, v1])
                with T.init():
                    temp_max[v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
                temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
        # Per-chunk sum: shifted exponentials when temperature > 1e-5, otherwise
        # a count of entries equal to the chunk maximum. Padded positions add 0.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("sum_exp"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(temperature[v0], A_pad[v0, v1, v2], temp_max[v0, v1])
                T.writes(temp_sum[v0, v1])
                with T.init():
                    temp_sum[v0, v1] = T.float32(0.0)
                temp_sum[v0, v1] = temp_sum[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.Cast("float32", A_pad[v0, v1, v2] == temp_max[v0, v1])), T.float32(0.0))
        # Emit outputs: log of the sum (or the raw count when temperature <= 1e-5) and the max.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
            with T.block("log"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], temp_sum[v0, v1], temp_max[v0, v1])
                T.writes(chunked_sum[v0, v1], chunked_max[v0, v1])
                chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum[v0, v1]), temp_sum[v0, v1])
                chunked_max[v0, v1] = temp_max[v0, v1]

    # In-place copy of entries between slots of a paged buffer `pages`.
    #
    # pages has a hard-coded shape (num_pages, 2, 2, 16, 128). A flat position p
    # maps to page p // 16, slot p % 16. Axis 1 (indices 0 and 1) is copied for
    # every pair (presumably K and V — confirm). Axis 2 is indexed by h in [0, 2)
    # and axis 4 by d in [0, 128).
    # For batch entry b, the columns copy_length_indptr[b] .. copy_length_indptr[b + 1] - 1
    # of copy_src_dst_pos hold (src_pos, dst_pos) pairs in rows 0 and 1. The pairs
    # are applied sequentially, in column order.
    # One GPU thread handles one (b, h, d) triple, so there are batch_size * 2 * 128
    # threads in blocks of 1024; surplus threads in the last block are masked out
    # by the bounds check.
    @T.prim_func
    def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1)
        total_copy_length = T.int32()
        copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes()
            # 256 = 2 (h) * 128 (d) threads per batch entry.
            for bhd_o in T.thread_binding((batch_size * 256 + 1023) // 1024, thread="blockIdx.x"):
                for bhd_i in T.thread_binding(1024, thread="threadIdx.x"):
                    # Decompose the global thread index into (b, h, d).
                    b: T.int32 = (bhd_o * 1024 + bhd_i) // 256
                    h: T.int32 = (bhd_o * 1024 + bhd_i) // 128 % 2
                    d: T.int32 = (bhd_o * 1024 + bhd_i) % 128
                    # Mask threads beyond batch_size * 256 in the last block.
                    if bhd_o * 1024 + bhd_i < batch_size * 2 * 128:
                        for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            # Copy both axis-1 planes (0 and 1) for this (h, d) element.
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d]

    @T.prim_func
    def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64):
        # Copy the first `copy_length` slots of page `src_page_id` into page `tgt_page_id`.
        # The copy covers both heads (vh in [0, 2)), all 128 d-lanes, and both slices 0 and 1
        # of the second `pages` axis (presumably K and V — confirm).
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        # page_size is symbolic here, unlike the fixed 16 in compact_kv_copy.
        num_pages, page_size = T.int32(), T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        # with T.block("root"):
        # One thread per (vh, vp, vd): copy_length * 2 * 128 = copy_length * 256 threads.
        for b in T.thread_binding((copy_length * T.int64(256) + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for t in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy"):
                    # Flat index idx = b * 1024 + t, decomposed as
                    # vh = idx // (copy_length * 128), vp = (idx % (copy_length * 128)) // 128, vd = idx % 128.
                    vh = T.axis.spatial(2, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(128))))
                    vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(128)) // T.int64(128))
                    vd = T.axis.spatial(128, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(128)))
                    # Guard the surplus threads of the last block.
                    T.where(b * T.int64(1024) + T.Cast("int64", t) < copy_length * T.int64(2) * T.int64(128))
                    T.reads(pages[src_page_id, 0:2, vh, vp, vd])
                    T.writes(pages[tgt_page_id, 0:2, vh, vp, vd])
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    @T.prim_func
    def full(var_result: T.handle, value: T.int32):
        # Fill the (batch_size, 1) int32 buffer `result` with the scalar `value`.
        # This function has no thread binding and no "tir.is_scheduled" attribute
        # (presumably a later pass schedules it for the GPU target — confirm).
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch_size = T.int32(is_size_var=True)
        result = T.match_buffer(var_result, (batch_size, 1), "int32")
        # with T.block("root"):
        for i in range(batch_size):
            with T.block("block"):
                vi = T.axis.spatial(batch_size, i)
                T.reads()
                T.writes(result[vi, 0])
                result[vi, 0] = value

    @T.prim_func(private=True)
    def fuse_add_norm_decode(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused residual add + RMS-norm-style normalization over a hidden size of 2048.
        # For each row bx of the (batch_size, 1, 2048) inputs:
        #   add[bx, 0, :] = A + B                                   (float16)
        #   O[bx, 0, h]   = rsqrt(sum_h(add^2) / 2048 + 1e-6) * add[h] * C[h]
        # The sum of squares is accumulated in float32. C is presumably the norm weight — confirm.
        # One thread block per row and 1024 threads per block; each thread owns hidden
        # positions tx and 1024 + tx.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        A = T.match_buffer(pA, (batch_size, 1, 2048), "float16")
        B = T.match_buffer(pB, (batch_size, 1, 2048), "float16")
        O = T.match_buffer(pO, (batch_size, 1, 2048), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        # Per-thread storage for this thread's two summed elements (index = h // 1024).
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        # Per-row sum of squares (default float32 dtype), shared by the block.
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        # Per-thread partial sums; each thread only touches its own [tx, bx, 0] entry.
        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Phase 1: compute A + B for h = i * 1024 + tx and write it back to `add`.
                for i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(A[bx, 0, h], B[bx, 0, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[bx, v_ax1, h])
                        add[bx, v_ax1, h] = add_local[h // 1024]
                # Phase 2a: per-thread float32 sum of squares of its two elements (rfactor).
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, bx, 0])
                    sum_local[tx, bx, 0] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, bx, 0], add_local[i])
                        T.writes(sum_local[tx, bx, 0])
                        sum_local[tx, bx, 0] = sum_local[tx, bx, 0] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Phase 2b: reduce the 1024 per-thread partials into sum_shared[bx, 0].
            # The reduction axis is bound to threadIdx.x (a cross-thread reduction).
            # `_j` is a single-iteration loop.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0.0)
                        sum_shared[bx, 0] = sum_shared[bx, 0] + sum_local[tx, bx, 0]
            # Phase 3: normalize and scale by C.
            # This reads the per-thread add_local written in phase 1, which relies on v_tx_2
            # being bound to the same threadIdx.x as v_tx.
            # 0.00048828125 == 1 / 2048 (mean over the hidden size); 1e-6 is the epsilon.
            for i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx_2)
                        T.reads(sum_shared[bx, 0], add_local[h // 1024], C[h])
                        T.writes(O[bx, 0, h])
                        O[bx, 0, h] = T.Cast("float16", T.rsqrt(sum_shared[bx, 0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[h // 1024]) * T.Cast("float32", C[h]))

    @T.prim_func(private=True)
    def fuse_add_norm_prefill(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Same fused residual add + RMS-norm-style normalization as fuse_add_norm_decode,
        # but for (1, seq_len, 2048) inputs: one thread block per sequence row bx.
        #   add[0, bx, :] = A + B
        #   O[0, bx, h]   = rsqrt(sum_h(add^2) / 2048 + 1e-6) * add[h] * C[h]
        # The sum of squares is accumulated in float32. C is presumably the norm weight — confirm.
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        A = T.match_buffer(pA, (1, seq_len, 2048), "float16")
        B = T.match_buffer(pB, (1, seq_len, 2048), "float16")
        O = T.match_buffer(pO, (1, seq_len, 2048), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        # Per-thread storage for this thread's two summed elements (index = h // 1024).
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        # Per-row sum of squares (default float32 dtype), shared by the block.
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        # Per-thread partial sums; each thread only touches its own [tx, 0, bx] entry.
        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                # Phase 1: compute A + B for h = v_i * 1024 + tx and write it back to `add`.
                for v_i in range(2):
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(A[0, bx, h], B[0, bx, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[0, bx, h])
                        add[0, bx, h] = add_local[h // 1024]
                # Phase 2a: per-thread float32 sum of squares of its two elements (rfactor).
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, 0, bx])
                    sum_local[tx, 0, bx] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, 0, bx], add_local[i])
                        T.writes(sum_local[tx, 0, bx])
                        sum_local[tx, 0, bx] = sum_local[tx, 0, bx] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Phase 2b: reduce the 1024 per-thread partials into sum_shared[0, bx].
            # The reduction axis is bound to threadIdx.x (a cross-thread reduction).
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, 0, bx])
                        T.writes(sum_shared[0, bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0.0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Phase 3: normalize and scale by C. It relies on v_tx_2 sharing threadIdx.x with
            # v_tx so the per-thread add_local values are still valid.
            # 0.00048828125 == 1 / 2048.
            for v_i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(2048, v_i * 1024 + v_tx_2)
                        T.reads(sum_shared[0, bx], add_local[v1 // 1024], C[v1])
                        T.writes(O[0, bx, v1])
                        O[0, bx, v1] = T.Cast("float16", T.rsqrt(sum_shared[0, bx] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[v1 // 1024]) * T.Cast("float32", C[v1]))

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul10_add2(model_layers_0_self_attn_c_attn_q_weight2: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale2: T.Buffer((T.int64(2560), T.int64(64)), "float16"), rms_norm73: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), model_layers_0_self_attn_c_attn_bias2: T.Buffer((T.int64(2560),), "float16"), T_add_intermediate_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")):
        # out[0, 0, n] = sum_k x[0, 0, k] * W[n, k] + bias[n], with W dequantized from 4-bit.
        # x is a static (1, 1, 2048) input; W is (2560, 2048).
        # Weight layout: each uint32 packs 8 consecutive 4-bit values along k, so there are
        # 256 words per row. Each group of 32 k-values shares one float16 scale (64 per row).
        # Unscheduled form: no "tir.is_scheduled" attribute, and the full-size intermediates
        # below are materialized as written.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8: shift right by 4 * (v_i1 % 8), mask with 15.
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_c_attn_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32]; q in [0, 15] maps to [-7, 8] before scaling.
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_c_attn_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale2[v_i0, v_i1 // T.int64(32)]
        # "NT" matmul: x @ W^T (W indexed [n, k]); accumulation is in float16.
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm73[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm73[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Add the per-output-feature bias.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], model_layers_0_self_attn_c_attn_bias2[v_ax2])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + model_layers_0_self_attn_c_attn_bias2[v_ax2]

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul5_add1(model_layers_0_self_attn_c_attn_q_weight3: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale3: T.Buffer((T.int64(2560), T.int64(64)), "float16"), p_rms_norm146: T.handle, model_layers_0_self_attn_c_attn_bias3: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle):
        # out[0, s, n] = sum_k x[0, s, k] * W[n, k] + bias[n] for the symbolic row count seq_len.
        # W (2560, 2048) is dequantized from 4-bit packed uint32 words (8 values per word)
        # with one float16 scale per 32 k-values and zero point 7.
        # Unscheduled form: no "tir.is_scheduled" attribute.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm146 = T.match_buffer(p_rms_norm146, (T.int64(1), seq_len, T.int64(2048)), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2560)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(2560)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8.
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_c_attn_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32].
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_c_attn_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v_i0, v_i1 // T.int64(32)]
        # x @ W^T, accumulated in float16.
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2560), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm146[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm146[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Add the per-output-feature bias.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2560)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], model_layers_0_self_attn_c_attn_bias3[v_ax2])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + model_layers_0_self_attn_c_attn_bias3[v_ax2]

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul_add(model_layers_0_self_attn_c_attn_q_weight4: T.Buffer((T.int64(2560), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale4: T.Buffer((T.int64(2560), T.int64(64)), "float16"), p_rms_norm219: T.handle, model_layers_0_self_attn_c_attn_bias4: T.Buffer((T.int64(2560),), "float16"), p_output0: T.handle):
        # out[b, 0, n] = sum_k x[b, 0, k] * W[n, k] + bias[n] for the symbolic row count batch_size.
        # W (2560, 2048) is dequantized from 4-bit packed uint32 words (8 values per word)
        # with one float16 scale per 32 k-values and zero point 7.
        # Unscheduled form: no "tir.is_scheduled" attribute.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm219 = T.match_buffer(p_rms_norm219, (batch_size, T.int64(1), T.int64(2048)), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2560)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(2560)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8.
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_c_attn_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32].
        for i0, i1 in T.grid(T.int64(2560), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_c_attn_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v_i0, v_i1 // T.int64(32)]
        # x @ W^T, accumulated in float16.
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2560), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm219[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm219[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Add the per-output-feature bias.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2560)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], model_layers_0_self_attn_c_attn_bias4[v_ax2])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + model_layers_0_self_attn_c_attn_bias4[v_ax2]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul1(model_layers_0_self_attn_o_proj_q_weight4: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale4: T.Buffer((T.int64(2048), T.int64(64)), "float16"), p_reshape435: T.handle, p_output0: T.handle):
        # out[b, 0, n] = sum_k x[b, 0, k] * W[n, k] (no bias) for the symbolic row count batch_size.
        # W (2048, 2048) is dequantized from 4-bit packed uint32 words (8 values per word)
        # with one float16 scale per 32 k-values and zero point 7.
        # Unscheduled form: no "tir.is_scheduled" attribute.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape435 = T.match_buffer(p_reshape435, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_o_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32].
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_o_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        # x @ W^T written directly to the output buffer, accumulated in float16.
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape435[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape435[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul11(model_layers_0_self_attn_o_proj_q_weight2: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale2: T.Buffer((T.int64(2048), T.int64(64)), "float16"), lv218: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # out[0, 0, n] = sum_k x[0, 0, k] * W[n, k] (no bias) for a static (1, 1, 2048) input.
        # W (2048, 2048) is dequantized from 4-bit packed uint32 words (8 values per word)
        # with one float16 scale per 32 k-values and zero point 7.
        # Unscheduled form: no "tir.is_scheduled" attribute.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_o_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32].
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_o_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        # x @ W^T written directly to the output buffer, accumulated in float16.
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv218[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv218[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul6(model_layers_0_self_attn_o_proj_q_weight3: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale3: T.Buffer((T.int64(2048), T.int64(64)), "float16"), p_reshape291: T.handle, p_output0: T.handle):
        # out[0, s, n] = sum_k x[0, s, k] * W[n, k] (no bias) for the symbolic row count seq_len.
        # W (2048, 2048) is dequantized from 4-bit packed uint32 words (8 values per word)
        # with one float16 scale per 32 k-values and zero point 7.
        # Unscheduled form: no "tir.is_scheduled" attribute.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape291 = T.match_buffer(p_reshape291, (T.int64(1), seq_len, T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_o_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32].
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_o_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        # x @ W^T written directly to the output buffer, accumulated in float16.
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape291[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape291[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul12(model_layers_0_mlp_gate_up_proj_q_weight2: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale2: T.Buffer((T.int64(22016), T.int64(64)), "float16"), rms_norm74: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(22016)), "float16")):
        # out[0, 0, n] = sum_k x[0, 0, k] * W[n, k] for a static (1, 1, 2048) input.
        # W is (22016, 2048), dequantized from 4-bit packed uint32 words (8 values per word)
        # with one float16 scale per 32 k-values and zero point 7.
        # Unscheduled form: no "tir.is_scheduled" attribute.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8.
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_gate_up_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32].
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_gate_up_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        # x @ W^T written directly to the output buffer, accumulated in float16.
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(22016), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm74[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm74[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul2(model_layers_0_mlp_gate_up_proj_q_weight4: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale4: T.Buffer((T.int64(22016), T.int64(64)), "float16"), p_rms_norm220: T.handle, p_output0: T.handle):
        # out[b, 0, n] = sum_k x[b, 0, k] * W[n, k] for the symbolic row count batch_size.
        # W is (22016, 2048), dequantized from 4-bit packed uint32 words (8 values per word)
        # with one float16 scale per 32 k-values and zero point 7.
        # Unscheduled form: no "tir.is_scheduled" attribute.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm220 = T.match_buffer(p_rms_norm220, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(22016)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        # Unpack nibble (v_i1 % 8) of word v_i1 // 8.
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_gate_up_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale[row, k // 32].
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_gate_up_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        # x @ W^T written directly to the output buffer, accumulated in float16.
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(22016), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm220[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm220[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul7(model_layers_0_mlp_gate_up_proj_q_weight3: T.Buffer((T.int64(22016), T.int64(256)), "uint32"), model_layers_0_mlp_gate_up_proj_q_scale3: T.Buffer((T.int64(22016), T.int64(64)), "float16"), p_rms_norm147: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + "NT" matmul (weight indexed as
        # [out_feature, k], i.e. out = x @ W^T) over a dynamic seq_len:
        #   x:   rms_norm147            (1, seq_len, 2048)  float16
        #   W:   packed q_weight3       (22016, 256) uint32 -> (22016, 2048)
        #        q_scale3               (22016, 64)  float16, one per 32 columns
        #   out: NT_matmul_intermediate (1, seq_len, 22016) float16
        # Accumulation is done entirely in float16 (no casts in the reduction).
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm147 = T.match_buffer(p_rms_norm147, (T.int64(1), seq_len, T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(22016)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(22016), T.int64(2048)), "float16")
        # Unpack: 8 x 4-bit fields per uint32 word; column v_i1 is field
        # (v_i1 % 8) of word (v_i1 // 8), lowest bits first -> value in [0, 15].
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_gate_up_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_gate_up_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale, with one scale per group of 32 columns.
        for i0, i1 in T.grid(T.int64(22016), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_gate_up_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_gate_up_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        # Reduction over k (2048): out[0, s, n] = sum_k x[0, s, k] * W[n, k].
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(22016), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm147[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm147[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul13(model_layers_0_mlp_down_proj_q_weight2: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale2: T.Buffer((T.int64(2048), T.int64(344)), "float16"), lv219: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Fused 4-bit weight dequantization + "NT" matmul with fully static shapes:
        #   x:   lv219                  (1, 1, 11008) float16
        #   W:   packed q_weight2       (2048, 1376) uint32 -> (2048, 11008)
        #        q_scale2               (2048, 344)  float16, one per 32 columns
        #   out: NT_matmul_intermediate (1, 1, 2048)  float16
        # Accumulation is done entirely in float16.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        # Unpack field (v_i1 % 8) of uint32 word (v_i1 // 8): 4 bits, value in [0, 15].
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_down_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale, one scale per group of 32 columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_down_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        # Reduction over k (11008): out[0, 0, n] = sum_k x[0, 0, k] * W[n, k].
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2048), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv219[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv219[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul3(model_layers_0_mlp_down_proj_q_weight4: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale4: T.Buffer((T.int64(2048), T.int64(344)), "float16"), p_lv1: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + "NT" matmul over a dynamic batch_size:
        #   x:   lv1                    (batch_size, 1, 11008) float16
        #   W:   packed q_weight4       (2048, 1376) uint32 -> (2048, 11008)
        #        q_scale4               (2048, 344)  float16, one per 32 columns
        #   out: NT_matmul_intermediate (batch_size, 1, 2048)  float16
        # Accumulation is done entirely in float16.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv1 = T.match_buffer(p_lv1, (batch_size, T.int64(1), T.int64(11008)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        # Unpack field (v_i1 % 8) of uint32 word (v_i1 // 8): 4 bits, value in [0, 15].
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_down_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale, one scale per group of 32 columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_down_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Reduction over k (11008): out[b, 0, n] = sum_k x[b, 0, k] * W[n, k].
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2048), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv1[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize4_NT_matmul8(model_layers_0_mlp_down_proj_q_weight3: T.Buffer((T.int64(2048), T.int64(1376)), "uint32"), model_layers_0_mlp_down_proj_q_scale3: T.Buffer((T.int64(2048), T.int64(344)), "float16"), p_lv73: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + "NT" matmul over a dynamic seq_len:
        #   x:   lv73                   (1, seq_len, 11008) float16
        #   W:   packed q_weight3       (2048, 1376) uint32 -> (2048, 11008)
        #        q_scale3               (2048, 344)  float16, one per 32 columns
        #   out: NT_matmul_intermediate (1, seq_len, 2048)  float16
        # Accumulation is done entirely in float16.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv73 = T.match_buffer(p_lv73, (T.int64(1), seq_len, T.int64(11008)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(11008)), "float16")
        # Unpack field (v_i1 % 8) of uint32 word (v_i1 // 8): 4 bits, value in [0, 15].
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_down_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_down_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale, one scale per group of 32 columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(11008)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_down_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_down_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        # Reduction over k (11008): out[0, s, n] = sum_k x[0, s, k] * W[n, k].
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2048), T.int64(11008)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv73[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv73[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize_NT_matmul14(model_embed_tokens_q_weight2: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale2: T.Buffer((T.int64(151936), T.int64(64)), "float16"), rms_norm145: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), "float32")):
        # Projects a single hidden vector onto all 151936 rows of the quantized
        # embed_tokens table (presumably a tied-embedding output head; confirm):
        #   x:   rms_norm145            (1, 1, 2048)   float16
        #   W:   packed q_weight2       (151936, 256) uint32 -> (151936, 2048)
        #        q_scale2               (151936, 64)  float16, one per 32 columns
        #   out: NT_matmul_intermediate (1, 1, 151936) float32
        # Unlike the float16 matmuls, products are cast to float32 and
        # accumulated in float32.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        # Unpack field (v_i1 % 8) of uint32 word (v_i1 // 8): 4 bits, value in [0, 15].
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_embed_tokens_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale, one scale per group of 32 columns.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_embed_tokens_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_embed_tokens_q_scale2[v_i0, v_i1 // T.int64(32)]
        # Reduction over k (2048) in float32: out[0, 0, n] = sum_k f32(x[k]) * f32(W[n, k]).
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm145[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", rms_norm145[v_i0, v_i1, v_k]) * T.Cast("float32", dequantize_intermediate[v_i2, v_k])

    @T.prim_func(private=True)
    def fused_dequantize_NT_matmul4(model_embed_tokens_q_weight4: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale4: T.Buffer((T.int64(151936), T.int64(64)), "float16"), p_rms_norm291: T.handle, p_output0: T.handle):
        # Batched projection onto all 151936 rows of the quantized embed_tokens
        # table (presumably a tied-embedding output head; confirm):
        #   x:   rms_norm291            (batch_size, 1, 2048)   float16
        #   W:   packed q_weight4       (151936, 256) uint32 -> (151936, 2048)
        #        q_scale4               (151936, 64)  float16, one per 32 columns
        #   out: NT_matmul_intermediate (batch_size, 1, 151936) float32
        #        (match_buffer's default dtype; the reduction writes float32)
        # Products are cast to float32 and accumulated in float32.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm291 = T.match_buffer(p_rms_norm291, (batch_size, T.int64(1), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(151936)))
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        # Unpack field (v_i1 % 8) of uint32 word (v_i1 // 8): 4 bits, value in [0, 15].
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_embed_tokens_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale, one scale per group of 32 columns.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_embed_tokens_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_embed_tokens_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Reduction over k (2048) in float32: out[b, 0, n] = sum_k f32(x[b, 0, k]) * f32(W[n, k]).
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm291[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", rms_norm291[v_i0, v_i1, v_k]) * T.Cast("float32", dequantize_intermediate[v_i2, v_k])

    @T.prim_func(private=True)
    def fused_dequantize_NT_matmul9(model_embed_tokens_q_weight3: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), model_embed_tokens_q_scale3: T.Buffer((T.int64(151936), T.int64(64)), "float16"), p_take1: T.handle, p_output0: T.handle):
        # Projection of batch_size hidden vectors (laid out along axis 1) onto
        # all 151936 rows of the quantized embed_tokens table:
        #   x:   take1                  (1, batch_size, 2048)   float16
        #   W:   packed q_weight3       (151936, 256) uint32 -> (151936, 2048)
        #        q_scale3               (151936, 64)  float16, one per 32 columns
        #   out: NT_matmul_intermediate (1, batch_size, 151936) float32
        #        (match_buffer's default dtype; the reduction writes float32)
        # Products are cast to float32 and accumulated in float32.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        take1 = T.match_buffer(p_take1, (T.int64(1), batch_size, T.int64(2048)), "float16")
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), batch_size, T.int64(151936)))
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        # Unpack field (v_i1 % 8) of uint32 word (v_i1 // 8): 4 bits, value in [0, 15].
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_embed_tokens_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (q - 7) * scale, one scale per group of 32 columns.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_embed_tokens_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_embed_tokens_q_scale3[v_i0, v_i1 // T.int64(32)]
        # Reduction over k (2048) in float32: out[0, b, n] = sum_k f32(x[0, b, k]) * f32(W[n, k]).
        for i0, i1, i2, k in T.grid(T.int64(1), batch_size, T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(take1[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", take1[v_i0, v_i1, v_k]) * T.Cast("float32", dequantize_intermediate[v_i2, v_k])

    @T.prim_func(private=True)
    def fused_dequantize_take1(model_embed_tokens_q_weight: T.Buffer((151936, 256), "uint32"), model_embed_tokens_q_scale: T.Buffer((151936, 64), "float16"), p_input_ids: T.handle, p_output0: T.handle):
        # Embedding lookup from a 4-bit quantized table (int32 shapes here):
        #   out[s, c] = (q[input_ids[s], c] - 7) * scale[input_ids[s], c // 32]
        #   input_ids:          (seq_len,)       int32
        #   T_take_intermediate (seq_len, 2048)  float16
        # input_ids values are used directly as row indices; nothing in this
        # function range-checks them against the 151936-row table.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int32()
        input_ids = T.match_buffer(p_input_ids, (seq_len,), "int32")
        T_take_intermediate = T.match_buffer(p_output0, (seq_len, 2048), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((151936, 2048), "float16")
        # Unpack field (v_i1 % 8) of uint32 word (v_i1 // 8): 4 bits, value in [0, 15].
        # As written, this stage covers all 151936 rows; T_take below only reads
        # the rows selected by input_ids.
        for i0, i1 in T.grid(151936, 2048):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_embed_tokens_q_weight[v_i0, v_i1 // 8])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15)))
        # Gather the selected row and dequantize it in the same step.
        for ax0, ax1 in T.grid(seq_len, 2048):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(input_ids[v_ax0], compute[input_ids[v_ax0], v_ax1], model_embed_tokens_q_scale[input_ids[v_ax0], v_ax1 // 32])
                T.writes(T_take_intermediate[v_ax0, v_ax1])
                T_take_intermediate[v_ax0, v_ax1] = (compute[input_ids[v_ax0], v_ax1] - T.float16(7.0)) * model_embed_tokens_q_scale[input_ids[v_ax0], v_ax1 // 32]

    @T.prim_func(private=True)
    def fused_reshape10_reshape11(lv184: T.Buffer((T.int64(1), T.int64(16), T.int64(128)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Two chained reshapes that flatten (1, 16, 128) into (1, 1, 2048):
        #   lv184 (1, 16, 128) -> T_reshape_intermediate (1, 1, 16, 128)
        #                      -> T_reshape_intermediate_1 (1, 1, 2048)
        # Pure data movement: out[0, 0, h * 128 + d] = lv184[0, h, d].
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(16), T.int64(128)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(16), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # v_ax3 < 128, so (v_ax3 // 128 + v_ax2) % 16 reduces to v_ax2.
                T.reads(lv184[T.int64(0), (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv184[T.int64(0), (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Flat index v_ax2 -> (head = v_ax2 // 128, dim = v_ax2 % 128).
                T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def fused_reshape8_reshape9(add108: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(20), T.int64(128)), "float16")):
        # Two chained reshapes that split a flat 2560 vector into 20 x 128:
        #   add108 (1, 1, 2560) -> T_reshape_intermediate (1, 1, 20, 128)
        #                       -> T_reshape_intermediate_1 (1, 20, 128)
        # Pure data movement: out[0, h, d] = add108[0, 0, h * 128 + d].
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(20), T.int64(128)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(20), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(add108[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(2560)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = add108[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(2560)]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(20), T.int64(128)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # v_ax2 < 128, so (v_ax2 // 128 + v_ax1) % 20 reduces to v_ax1.
                T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(128) + v_ax1) % T.int64(20), v_ax2 % T.int64(128)])
                T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(128) + v_ax1) % T.int64(20), v_ax2 % T.int64(128)]

    @T.prim_func
    def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32):
        # Splits a packed (seq_len, 20, 128) float16 QKV tensor by head index and
        # applies rotary position embedding to the q and k heads:
        #   heads  0..15 -> q[s, h, d]       (rotated when apply_rope > 0)
        #   heads 16..17 -> k[s, h - 16, d]  (rotated when apply_rope > 0)
        #   heads 18..19 -> v[s, h - 18, d]  (copied unchanged)
        # Rotation uses the "rotate-half" pairing of dim d with d +/- 64:
        #   out[d] = cos(f) * x[d] + sin(f) * (d < 64 ? -x[d + 64] : x[d - 64])
        #   f      = position_map[s] / 1e6 ** ((2 * d % 128) / 128)
        # so dims d and d + 64 share the same frequency. When apply_rope <= 0
        # q and k receive qkv unchanged.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        qkv = T.match_buffer(var_qkv, (seq_len, 20, 128), "float16")
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 16, 128), "float16")
        k = T.match_buffer(var_k, (seq_len, 2, 128), "float16")
        v = T.match_buffer(var_v, (seq_len, 2, 128), "float16")
        # with T.block("root"):
        for iters_0, iters_1, iters_2 in T.grid(seq_len, 20, 128):
            with T.block("llama_fused_rope"):
                s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2])
                # Read region covers the partner element at d - 64 .. d + 64.
                T.reads(position_map[s], qkv[s, h, d - 64:d - 64 + 129])
                T.writes(q[s, h, d], k[s, h - 16, d], v[s, h - 18, d])
                if h < 16:
                    # freq is a placeholder bound by the T.Let(..., where=...) below.
                    freq = T.float32()
                    # "d < 128" always holds for d in [0, 128); only apply_rope gates rotation.
                    q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                else:
                    if h < 18:
                        # Same rotation as the q branch, written into k's 2 heads.
                        freq = T.float32()
                        k[s, h - 16, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                    else:
                        # v heads are passed through without rotation.
                        v[s, h - 18, d] = qkv[s, h, d]

    @T.prim_func(private=True)
    def fused_split1_silu1_multiply1(p_lv147: T.handle, p_output0: T.handle):
        # Gated activation over a dynamic seq_len:
        #   a   = lv147[..., 0:11008]       (first half)
        #   b   = lv147[..., 11008:22016]   (second half)
        #   out = silu(a) * b, where silu(a) = a * sigmoid(a)
        #   lv147: (1, seq_len, 22016) float16 -> out: (1, seq_len, 11008) float16
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv147 = T.match_buffer(p_lv147, (T.int64(1), seq_len, T.int64(22016)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(11008)), "float16")
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(11008)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((T.int64(1), seq_len, T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), seq_len, T.int64(11008)), "float16")
        T_multiply_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(11008)), "float16")
        # a = first 11008 channels.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(11008)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv147[v_ax0, v_ax1, v_ax2])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] = lv147[v_ax0, v_ax1, v_ax2]
        # b = last 11008 channels (offset by 11008).
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(11008)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv147[v_ax0, v_ax1, v_ax2 + T.int64(11008)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] = lv147[v_ax0, v_ax1, v_ax2 + T.int64(11008)]
        # sigmoid(a)
        for i0, i1, i2 in T.grid(T.int64(1), seq_len, T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_split_sections_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1, v_i2])
        # silu(a) = a * sigmoid(a)
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
        # out = silu(a) * b
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(11008)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_multiply_intermediate[v_ax0, v_ax1, v_ax2], T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2] = T_multiply_intermediate[v_ax0, v_ax1, v_ax2] * T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2]

    @T.prim_func(private=True)
    def fused_split2_silu2_multiply2(lv437: T.Buffer((T.int64(1), T.int64(1), T.int64(22016)), "float16"), T_multiply_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")):
        # Gated activation with fully static shapes:
        #   a   = lv437[..., 0:11008], b = lv437[..., 11008:22016]
        #   out = silu(a) * b, where silu(a) = a * sigmoid(a)
        #   lv437: (1, 1, 22016) float16 -> out: (1, 1, 11008) float16
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
        compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
        T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16")
        # a = first 11008 channels.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv437[v_ax0, v_ax1, v_ax2])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] = lv437[v_ax0, v_ax1, v_ax2]
        # b = last 11008 channels (offset by 11008).
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv437[v_ax0, v_ax1, v_ax2 + T.int64(11008)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] = lv437[v_ax0, v_ax1, v_ax2 + T.int64(11008)]
        # sigmoid(a)
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_split_sections_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1, v_i2])
        # silu(a) = a * sigmoid(a)
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
        # out = silu(a) * b
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(11008)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_multiply_intermediate[v_ax0, v_ax1, v_ax2], T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2] = T_multiply_intermediate[v_ax0, v_ax1, v_ax2] * T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2]

    @T.prim_func(private=True)
    def fused_split_silu_multiply(p_lv2: T.handle, p_output0: T.handle):
        # Gated activation over a dynamic batch_size:
        #   a   = lv2[..., 0:11008], b = lv2[..., 11008:22016]
        #   out = silu(a) * b, where silu(a) = a * sigmoid(a)
        #   lv2: (batch_size, 1, 22016) float16 -> out: (batch_size, 1, 11008) float16
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv2 = T.match_buffer(p_lv2, (batch_size, T.int64(1), T.int64(22016)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(11008)), "float16")
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(11008)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((batch_size, T.int64(1), T.int64(11008)), "float16")
        compute = T.alloc_buffer((batch_size, T.int64(1), T.int64(11008)), "float16")
        T_multiply_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(11008)), "float16")
        # a = first 11008 channels.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(11008)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2]
        # b = last 11008 channels (offset by 11008).
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(11008)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv2[v_ax0, v_ax1, v_ax2 + T.int64(11008)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2] = lv2[v_ax0, v_ax1, v_ax2 + T.int64(11008)]
        # sigmoid(a)
        for i0, i1, i2 in T.grid(batch_size, T.int64(1), T.int64(11008)):
            with T.block("compute"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(T_split_sections_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute[v_i0, v_i1, v_i2])
                compute[v_i0, v_i1, v_i2] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1, v_i2])
        # silu(a) = a * sigmoid(a)
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(11008)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1, v_ax2], compute[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = T_split_sections_intermediate[v_ax0, v_ax1, v_ax2] * compute[v_ax0, v_ax1, v_ax2]
        # out = silu(a) * b
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(11008)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_multiply_intermediate[v_ax0, v_ax1, v_ax2], T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate_1[v_ax0, v_ax1, v_ax2] = T_multiply_intermediate[v_ax0, v_ax1, v_ax2] * T_split_sections_intermediate_1[v_ax0, v_ax1, v_ax2]

    @T.prim_func
    def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row gather: dst[r, c] = src[indices[r], c] for every output row r in
        # [0, batch_size) and every column c in [0, n).
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (m, n))
        batch_size = T.int32(is_size_var=True)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (batch_size, n))
        for row, col in T.grid(batch_size, n):
            with T.block("gather_2d"):
                v_row = T.axis.spatial(batch_size, row)
                v_col = T.axis.spatial(n, col)
                # The source row is chosen indirectly through `indices`.
                T.reads(indices[v_row], src[indices[v_row], v_col])
                T.writes(dst[v_row, v_col])
                dst[v_row, v_col] = src[indices[v_row], v_col]

    @T.prim_func(private=True)
    def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
        # Map each uniform sample to an entry of `indices` using a per-row
        # cumulative distribution.
        #
        # For output row b, let r = sample_indices[b, 0] and
        # thr(j) = cumsum_sorted[r, j] / renorm_prob[r, 0]. The block instance at
        # column j writes output_index[b, 0] = indices[r, j] when
        #   (usample[b, 0] < thr(j) or j == vocab_size - 1) and
        #   (j == 0 or usample[b, 0] >= thr(j - 1)).
        # NOTE(review): if each cumsum_sorted row is non-decreasing (presumably a
        # cumulative sum of probabilities in sorted order -- confirm with the
        # caller), exactly one column satisfies this, so each output row has a
        # single writer.
        #
        # A: cumsum_sorted (batch, vocab_size)
        # B: indices (batch, vocab_size) int32 -- presumably the original token id
        #    of each sorted position (TODO confirm)
        # C: renorm_prob (batch, 1), divisor applied to cumsum_sorted
        # D: usample (out_batch, 1), the sample values
        # E: sample_indices (out_batch, 1) int32, selects the row of A/B/C per sample
        # F: output_index (out_batch, 1) int32, written in place
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        indices = T.match_buffer(B, (batch, vocab_size), "int32")
        renorm_prob = T.match_buffer(C, (batch, 1))
        out_batch = T.int64(is_size_var=True)
        usample = T.match_buffer(D, (out_batch, 1))
        sample_indices = T.match_buffer(E, (out_batch, 1), "int32")
        output_index = T.match_buffer(F, (out_batch, 1), "int32")
        # with T.block("root"):
        for ax0, ax1 in T.grid(out_batch, vocab_size):
            with T.block("T_get_index_from_sorted"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # Auto-derived access regions. For v_ax1 == 0 the cumsum region
                # starts at -1, but the v_ax1 - 1 access below only happens on the
                # v_ax1 != 0 branch.
                T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
                T.writes(output_index[v_ax0, 0])
                # Sample falls below this column's threshold, or this is the last
                # column (fallback when the sample exceeds every threshold).
                if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
                    if v_ax1 == T.int64(0):
                        output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                    else:
                        # ...and not below the previous column's threshold.
                        if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                            output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

    @T.prim_func(private=True)
    def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        # Compute a per-row normalisation constant from top_p / top_k.
        #
        # Let keep(j) := cumsum_sorted[b, j] < top_p[b, 0] and j + 1 < top_k[b, 0].
        # renorm_prob[b, 0] receives cumsum_sorted[b, j*], where j* is the first
        # column at which keep() is false, or the last column if keep() holds
        # everywhere:
        #   * if keep(0) is false, every column writes cumsum_sorted[b, 0]
        #     (all writers store the same value);
        #   * otherwise the column j with keep(j) true and keep(j + 1) false
        #     writes cumsum_sorted[b, j + 1], and if keep(vocab_size - 1) holds
        #     that column writes cumsum_sorted[b, vocab_size - 1].
        # NOTE(review): the single-writer property assumes keep() is true on a
        # prefix of columns (e.g. non-decreasing cumsum_sorted rows) -- confirm.
        #
        # A: cumsum_sorted (batch, vocab_size); B: top_p (batch, 1);
        # C: top_k (batch, 1) int32; D: renorm_prob (batch, 1), written in place.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        top_p = T.match_buffer(B, (batch, 1))
        top_k = T.match_buffer(C, (batch, 1), "int32")
        renorm_prob = T.match_buffer(D, (batch, 1))
        # with T.block("root"):
        for ax0, ax1 in T.grid(batch, vocab_size):
            with T.block("T_get_renorm_prob"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0])
                T.writes(renorm_prob[v_ax0, 0])
                # keep(0) is false: the cutoff is the first column.
                if not (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > 1):
                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
                else:
                    # keep(v_ax1) holds ...
                    if cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0]):
                        # ... on the last column: use the full cumulative sum.
                        if v_ax1 + T.int64(1) == vocab_size:
                            renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
                        else:
                            # ... and keep(v_ax1 + 1) fails: the cutoff is v_ax1 + 1.
                            if not (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0])):
                                renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]

    @T.prim_func(private=True)
    def gpu_2d_continuous_cumsum(var_a: T.handle, var_out: T.handle):
        # Inclusive prefix sum along the last axis of an (m, n) buffer:
        # Out[r, j] = A[r, 0] + ... + A[r, j] for every row r. Values are cast to
        # float32 when loaded.
        #
        # Each row is processed in tiles of 512 elements, one thread block
        # (4 x 32 threads, 4 consecutive elements per thread) per (row, tile):
        #   1. scan every tile of A into Out and store each tile's total in Tmp;
        #   2. for total_rounds rounds, scan the tile totals of the previous level
        #      the same way. Level i lives at Tmp column offset i * ceil(n / 512),
        #      and its tile totals go to offset (i + 1) * ceil(n / 512);
        #   3. walk the levels back down, adding the scanned total of the preceding
        #      tiles (one level up) to each element of the lower level;
        #   4. add Tmp[r, t - 1] to every element of Out tile t > 0.
        # Each level is 512 = 2**9 times shorter than the previous one, hence
        # "// 9" when deriving total_rounds from ceil(log2(n)).
        # NOTE(review): no explicit barriers appear between the shared-memory scan
        # steps; presumably TVM's thread-sync lowering inserts them -- confirm.
        T.func_attr({"tir.is_scheduled": 1})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_a, (m, n))
        Out = T.match_buffer(var_out, (m, n))
        # with T.block("root"):
        Tmp = T.alloc_buffer((m, n))
        ceil_log2: T.int64 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        total_rounds: T.int64 = ceil_log2 // T.int64(9)
        # Phase 1: scan each 512-element tile of A into Out; tile totals -> Tmp[by, bx].
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                with T.block(""):
                    T.reads(A[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)])
                    T.writes(Out[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)], Tmp[by, bx])
                    local_buf = T.alloc_buffer((4,), scope="local")
                    shared_buf = T.alloc_buffer((T.int64(512),), scope="shared")
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            # Load 4 consecutive elements (0 past n) and scan them locally.
                            tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                            for i in T.vectorized(T.int64(4)):
                                local_buf[i] = T.if_then_else(tx_idx + i < n, T.Cast("float32", A[by, tx_idx + i]), T.Cast("float32", 0))
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                local_buf[i] = local_buf[i] + local_buf[i - T.int64(1)]
                            for i in T.vectorized(T.int64(4)):
                                shared_buf[ty * T.int64(128) + tx * T.int64(4) + i] = local_buf[i]
                            # Scan across the 32 threads of each threadIdx.y row (one
                            # 128-element segment): step i adds the last element owned
                            # by thread tx - 2**i.
                            for i in T.unroll(T.int64(5)):
                                for j in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                    if tx >= T.shift_left(T.int64(1), i):
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                            # Threads with ty == 0 add the running total of the preceding
                            # segments to segments 1..3, in order.
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                for j in T.vectorized(T.int64(4)):
                                    if ty == T.int64(0):
                                        idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i * T.int64(128) - T.int64(1)]
                            # Store the scanned tile, dropping padding past n.
                            for i in T.vectorized(T.int64(4)):
                                idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i
                                if bx * T.int64(512) + idx < n:
                                    Out[by, bx * T.int64(512) + idx] = shared_buf[idx]
                            # Thread (0, 0) records the tile total (each vectorized
                            # iteration stores the same value).
                            if tx == T.int64(0) and ty == T.int64(0):
                                for i in T.vectorized(T.int64(4)):
                                    Tmp[by, bx] = shared_buf[T.int64(511)]
        # Phase 2: round i scans the cur_len totals of level i (Tmp offset
        # i * ceil(n / 512)) and stores each tile's total at offset
        # (i + 1) * ceil(n / 512) + bx.
        for i in range(total_rounds):
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    with T.block(""):
                        T.reads(Tmp[by, bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)):bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(512)])
                        T.writes(Tmp[by, T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx):T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + (T.max(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(511), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + T.int64(1) - T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx))])
                        local_buf = T.alloc_buffer((4,), scope="local")
                        shared_buf = T.alloc_buffer((T.int64(512),), scope="shared")
                        for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                # Same in-tile scan as phase 1, reading from and writing to Tmp.
                                tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                                for i_1 in T.vectorized(T.int64(4)):
                                    local_buf[i_1] = T.if_then_else(tx_idx + i_1 < cur_len, T.Cast("float32", Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + tx_idx + i_1]), T.Cast("float32", 0))
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    local_buf[i_1] = local_buf[i_1] + local_buf[i_1 - T.int64(1)]
                                for i_1 in T.vectorized(T.int64(4)):
                                    shared_buf[ty * T.int64(128) + tx * T.int64(4) + i_1] = local_buf[i_1]
                                for i_1 in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i_1):
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i_1) * T.int64(4) + T.int64(4) - T.int64(1)]
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i_1 * T.int64(128) + tx * T.int64(4)
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i_1 * T.int64(128) - T.int64(1)]
                                for i_1 in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i_1
                                    if bx * T.int64(512) + idx < cur_len:
                                        Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx * T.int64(512) + idx] = shared_buf[idx]
                                if tx == T.int64(0) and ty == T.int64(0):
                                    for i_1 in T.vectorized(T.int64(4)):
                                        Tmp[by, (i + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx] = shared_buf[T.int64(511)]
        # Phase 3: from level total_rounds - 2 down to 0, add the scanned total of
        # the preceding tiles (stored one level up) to elements in tiles bx > 0.
        for i in range(total_rounds - T.int64(1)):
            real_idx: T.int64 = total_rounds - T.int64(1) - i - T.int64(1)
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            for i_1 in range(T.int64(4)):
                                idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i_1 * T.int64(32) + tx
                                if idx < cur_len:
                                    Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] = Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] + T.if_then_else(bx > T.int64(0), Tmp[by, (real_idx + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx - T.int64(1)], T.float32(0.0))
        # Phase 4: add the total of all preceding tiles (Tmp level 0) to Out.
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        for i in range(T.int64(4)):
                            idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i * T.int64(32) + tx
                            if idx < n:
                                Out[by, idx] = Out[by, idx] + T.if_then_else(bx > T.int64(0), Tmp[by, bx - T.int64(1)], T.float32(0.0))

    @T.prim_func(private=True)
    def index(var_rms_norm72: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Copy the final position along the sequence axis:
        # index[0, 0, k] = rms_norm72[0, seq_len - 1, k] for k in [0, 2048).
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm72 = T.match_buffer(var_rms_norm72, (T.int64(1), seq_len, T.int64(2048)), "float16")
        for b, s, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("index"):
                v_b = T.axis.spatial(T.int64(1), b)
                v_s = T.axis.spatial(T.int64(1), s)
                v_k = T.axis.spatial(T.int64(2048), k)
                T.reads(rms_norm72[v_b, seq_len - T.int64(1), v_k])
                T.writes(index[v_b, v_s, v_k])
                index[v_b, v_s, v_k] = rms_norm72[v_b, seq_len - T.int64(1), v_k]

    @T.prim_func
    def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle):
        # Merge a second partial result (V_other, S_other) into (V, S) in place.
        #
        # S / S_other hold one score per (n, h) in log2 space. With
        # m = max(S, S_other), w = 2**(S - m) and w_o = 2**(S_other - m):
        #   V <- (w * V + w_o * V_other) / (w + w_o)   (blended in float32, stored as float16)
        #   S <- log2(w + w_o) + m
        # Presumably this combines attention partial outputs and their base-2
        # log-sum-exp values -- TODO confirm with the caller.
        #
        # NOTE(review): the launch covers 1 * 16 heads (blockIdx.y x threadIdx.y)
        # and 32 * 4 = 128 elements of D (threadIdx.x x 4-wide vectors) with no
        # bounds checks, so it is only correct for H == 16 and D == 128 even though
        # both are symbolic in the buffer shapes.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True)
        V = T.match_buffer(v, (N, H, D), "float16")
        S = T.match_buffer(s, (N, H))
        V_other = T.match_buffer(v_other, (N, H, D), "float16")
        S_other = T.match_buffer(s_other, (N, H))
        # with T.block("root"):
        for bx in T.thread_binding(N, thread="blockIdx.x"):
            for by in T.thread_binding(1, thread="blockIdx.y"):
                for ty in T.thread_binding(16, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("merge"):
                            T.reads(S[bx, ty + by * 16], S_other[bx, ty + by * 16], V[bx, ty + by * 16, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 16, tx * 4:tx * 4 + 4])
                            T.writes(V[bx, ty + by * 16, tx * 4:tx * 4 + 4], S[bx, ty + by * 16])
                            s_val = T.alloc_buffer((1,), scope="local")
                            s_other_val = T.alloc_buffer((1,), scope="local")
                            s_max = T.alloc_buffer((1,), scope="local")
                            scale = T.alloc_buffer((1,), scope="local")
                            other_scale = T.alloc_buffer((1,), scope="local")
                            v_vec = T.alloc_buffer((4,), "float16", scope="local")
                            v_other_vec = T.alloc_buffer((4,), "float16", scope="local")
                            # Subtract the max before exp2 so both weights are <= 1.
                            s_val[0] = S[bx, ty + by * 16]
                            s_other_val[0] = S_other[bx, ty + by * 16]
                            s_max[0] = T.max(s_val[0], s_other_val[0])
                            s_val[0] = T.exp2(s_val[0] - s_max[0])
                            s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
                            scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
                            other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
                            # Each thread handles 4 consecutive elements of D.
                            for vec in T.vectorized(4):
                                v_vec[vec] = V[bx, ty + by * 16, tx * 4 + vec]
                            for vec in T.vectorized(4):
                                v_other_vec[vec] = V_other[bx, ty + by * 16, tx * 4 + vec]
                            for vec in range(4):
                                v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0])
                            for vec in T.vectorized(4):
                                V[bx, ty + by * 16, tx * 4 + vec] = v_vec[vec]
                            S[bx, ty + by * 16] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]

    @T.prim_func
    def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle):
        # Inverse-CDF sampling. For batch entry bx, with row r = row_indices[bx, 0]
        # and u = uniform_samples[bx, 0], pick the index at which the running
        # (inclusive) sum of the positive entries of prob[r, :] first reaches
        # u - 1e-6, and write it to token_ids[bx, 0].
        #
        # One thread block of 4 x 32 threads serves one batch entry. It walks the
        # row in chunks of 512 elements (4 consecutive elements per thread). Each
        # chunk is first reduced to its sum; only the chunk in which the running
        # total crosses u is scanned element-wise to locate the index.
        # If no chunk crosses u, the result is vocab_size - 1.
        # NOTE(review): if u < 1e-6 and prob[r, 0] <= 0, no element is masked in
        # the first chunk and the result is also vocab_size - 1.
        T.func_attr({"tir.is_scheduled": 1})
        n, vocab_size = T.int64(), T.int64()
        prob = T.match_buffer(var_prob, (n, vocab_size))
        batch_size = T.int64()
        uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1))
        row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int32")
        token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int32")
        # with T.block("root"):
        aggregate = T.alloc_buffer((), scope="local")
        sample_id_local = T.alloc_buffer((), "int32", scope="local")
        step_iter = T.alloc_buffer((), "int32", scope="local")
        for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            row_idx: T.int32 = row_indices[bx, 0]
            for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    u: T.float32 = uniform_samples[bx, 0]
                    aggregate[()] = T.Cast("float32", 0)
                    step_iter[()] = 0
                    # Fallback result. sample_id_local is otherwise assigned only in
                    # the "crossing" branch below. If the running sum never reaches
                    # u - 1e-6 (e.g. the row's positive probabilities sum to slightly
                    # less than u), the loop exits without entering that branch, and
                    # the final store would read an uninitialised value.
                    # vocab_size - 1 matches the default of the index reduction below.
                    sample_id_local[()] = T.Cast("int32", vocab_size - T.int64(1))
                    # Run at least one chunk; stop once the running sum reaches
                    # u - 1e-6 or the row is exhausted. The condition is declared
                    # thread-invariant (identical in every thread of the block).
                    while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)):
                        with T.block(""):
                            T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()])
                            T.writes(sample_id_local[()], aggregate[()])
                            prob_gt_threshold = T.alloc_buffer((T.int64(4),), scope="local")
                            cumsum = T.alloc_buffer((T.int64(512),), scope="shared")
                            greater_than_u = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            mask = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            valid = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            indices = T.alloc_buffer((T.int64(4),), "int32", scope="local")
                            step_aggregate = T.alloc_buffer((), scope="local")
                            # Load this thread's 4 elements. Out-of-range and non-positive
                            # probabilities contribute 0 and are marked invalid.
                            for v in T.unroll(T.int64(4)):
                                idx: T.int64 = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v
                                prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0))
                                prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0.0), prob_local, T.Cast("float32", 0))
                                valid[v] = prob_local > T.float32(0.0) and idx < vocab_size
                            # Chunk sum: per-thread partial sums, then a 7-step tree
                            # reduction over the 128 threads in shared memory.
                            with T.block(""):
                                T.reads(prob_gt_threshold[T.int64(0):T.int64(4)])
                                T.writes(step_aggregate[()])
                                local_sum = T.alloc_buffer((), scope="local")
                                shared_buf = T.alloc_buffer((T.int64(128),), scope="shared")
                                idx: T.int64 = ty * T.int64(32) + tx
                                local_sum[()] = T.Cast("float32", 0)
                                for i in T.unroll(T.int64(4)):
                                    local_sum[()] = local_sum[()] + prob_gt_threshold[i]
                                shared_buf[idx] = local_sum[()]
                                for i in T.unroll(T.int64(7)):
                                    if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                        shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(T.int64(1), i)]
                                step_aggregate[()] = shared_buf[0]
                            # The running sum crosses u inside this chunk: locate the index.
                            if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)):
                                # Inclusive scan of the chunk: per thread, then across the
                                # 32 threads of each threadIdx.y row, then across rows.
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - T.int64(1)]
                                for i in T.vectorized(T.int64(4)):
                                    cumsum[ty * T.int64(128) + tx * T.int64(4) + i] = prob_gt_threshold[i]
                                for i in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i):
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)]
                                # greater_than_u[v]: running sum through this element reaches u - 1e-6.
                                for v in T.unroll(T.int64(4)):
                                    greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07)
                                # mask[v]: greater_than_u flips relative to the previous element
                                # (the previous thread's last flag is exchanged via shared memory).
                                with T.block(""):
                                    T.reads(greater_than_u[T.int64(0):T.int64(4)])
                                    T.writes(mask[T.int64(0):T.int64(4)])
                                    shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared")
                                    tx_idx: T.int64 = ty * T.int64(32) + tx
                                    shared_buf[tx_idx] = greater_than_u[T.int64(3)]
                                    mask[0] = T.if_then_else(tx_idx != T.int64(0), T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - T.int64(1)]), greater_than_u[0])
                                    for i in T.unroll(T.int64(1), T.int64(4)):
                                        mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - T.int64(1)])
                                # Keep only valid elements and record their absolute indices.
                                for v in T.unroll(T.int64(4)):
                                    mask[v] = mask[v] and valid[v]
                                    indices[v] = T.Cast("int32", T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v)
                                # Block-wide min over masked indices (vocab_size - 1 if none).
                                with T.block(""):
                                    T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)])
                                    T.writes(sample_id_local[()])
                                    local_sum = T.alloc_buffer((), "int32", scope="local")
                                    shared_buf = T.alloc_buffer((T.int64(128),), "int32", scope="shared")
                                    idx: T.int64 = ty * T.int64(32) + tx
                                    local_sum[()] = T.Cast("int32", vocab_size - T.int64(1))
                                    for i in T.unroll(T.int64(4)):
                                        if mask[i]:
                                            local_sum[()] = T.min(local_sum[()], indices[i])
                                    shared_buf[idx] = local_sum[()]
                                    for i in T.unroll(T.int64(7)):
                                        if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                            shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(T.int64(1), i)])
                                    sample_id_local[()] = shared_buf[0]
                            aggregate[()] = aggregate[()] + step_aggregate[()]
                        step_iter[()] = step_iter[()] + 1
                    if tx == T.int64(0) and ty == T.int64(0):
                        token_ids[bx, 0] = sample_id_local[()]

    @T.prim_func(private=True)
    def reshape(var_add324: T.handle, var_T_reshape: T.handle):
        # View (batch_size, 1, 2560) as (batch_size, 1, 20, 128) by splitting the
        # last axis into 20 groups of 128: out[b, 0, g, d] = add324[b, 0, g * 128 + d].
        # This equals the generic flattened-index form because g * 128 + d < 2560,
        # the unit axis is 0 and b < batch_size.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        add324 = T.match_buffer(var_add324, (batch_size, T.int64(1), T.int64(2560)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(20), T.int64(128)), "float16")
        for b, s, g, d in T.grid(batch_size, T.int64(1), T.int64(20), T.int64(128)):
            with T.block("T_reshape"):
                v_b = T.axis.spatial(batch_size, b)
                v_s = T.axis.spatial(T.int64(1), s)
                v_g = T.axis.spatial(T.int64(20), g)
                v_d = T.axis.spatial(T.int64(128), d)
                T.reads(add324[v_b, T.int64(0), v_g * T.int64(128) + v_d])
                T.writes(T_reshape[v_b, v_s, v_g, v_d])
                T_reshape[v_b, v_s, v_g, v_d] = add324[v_b, T.int64(0), v_g * T.int64(128) + v_d]

    @T.prim_func(private=True)
    def reshape1(var_reshape432: T.handle, var_T_reshape: T.handle):
        # Drop the unit axis: (batch_size, 1, 20, 128) -> (batch_size, 20, 128),
        # i.e. out[b, g, d] = reshape432[b, 0, g, d].
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape432 = T.match_buffer(var_reshape432, (batch_size, T.int64(1), T.int64(20), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(20), T.int64(128)), "float16")
        for b, g, d in T.grid(batch_size, T.int64(20), T.int64(128)):
            with T.block("T_reshape"):
                v_b = T.axis.spatial(batch_size, b)
                v_g = T.axis.spatial(T.int64(20), g)
                v_d = T.axis.spatial(T.int64(128), d)
                T.reads(reshape432[v_b, T.int64(0), v_g, v_d])
                T.writes(T_reshape[v_b, v_g, v_d])
                T_reshape[v_b, v_g, v_d] = reshape432[v_b, T.int64(0), v_g, v_d]

    @T.prim_func(private=True)
    def reshape2(var_lv546: T.handle, var_T_reshape: T.handle):
        # Insert a unit axis: (batch_size, 16, 128) -> (batch_size, 1, 16, 128),
        # i.e. out[b, 0, g, d] = lv546[b, g, d].
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv546 = T.match_buffer(var_lv546, (batch_size, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(16), T.int64(128)), "float16")
        for b, s, g, d in T.grid(batch_size, T.int64(1), T.int64(16), T.int64(128)):
            with T.block("T_reshape"):
                v_b = T.axis.spatial(batch_size, b)
                v_s = T.axis.spatial(T.int64(1), s)
                v_g = T.axis.spatial(T.int64(16), g)
                v_d = T.axis.spatial(T.int64(128), d)
                T.reads(lv546[v_b, v_g, v_d])
                T.writes(T_reshape[v_b, v_s, v_g, v_d])
                T_reshape[v_b, v_s, v_g, v_d] = lv546[v_b, v_g, v_d]

    @T.prim_func(private=True)
    def reshape3(var_reshape434: T.handle, var_T_reshape: T.handle):
        # Merge the last two axes: (batch_size, 1, 16, 128) -> (batch_size, 1, 2048),
        # i.e. out[b, 0, x] = reshape434[b, 0, x // 128, x % 128].
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape434 = T.match_buffer(var_reshape434, (batch_size, T.int64(1), T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(2048)), "float16")
        for b, s, x in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_reshape"):
                v_b = T.axis.spatial(batch_size, b)
                v_s = T.axis.spatial(T.int64(1), s)
                v_x = T.axis.spatial(T.int64(2048), x)
                T.reads(reshape434[v_b, T.int64(0), v_x // T.int64(128), v_x % T.int64(128)])
                T.writes(T_reshape[v_b, v_s, v_x])
                T_reshape[v_b, v_s, v_x] = reshape434[v_b, T.int64(0), v_x // T.int64(128), v_x % T.int64(128)]

    @T.prim_func(private=True)
    def reshape4(var_add216: T.handle, var_T_reshape: T.handle):
        # Split the last axis: (1, seq_len, 2560) -> (1, seq_len, 20, 128),
        # i.e. out[0, t, g, d] = add216[0, t, g * 128 + d].
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        add216 = T.match_buffer(var_add216, (T.int64(1), seq_len, T.int64(2560)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(20), T.int64(128)), "float16")
        for b, t, g, d in T.grid(T.int64(1), seq_len, T.int64(20), T.int64(128)):
            with T.block("T_reshape"):
                v_b = T.axis.spatial(T.int64(1), b)
                v_t = T.axis.spatial(seq_len, t)
                v_g = T.axis.spatial(T.int64(20), g)
                v_d = T.axis.spatial(T.int64(128), d)
                T.reads(add216[T.int64(0), v_t, v_g * T.int64(128) + v_d])
                T.writes(T_reshape[v_b, v_t, v_g, v_d])
                T_reshape[v_b, v_t, v_g, v_d] = add216[T.int64(0), v_t, v_g * T.int64(128) + v_d]

    @T.prim_func(private=True)
    def reshape5(var_reshape288: T.handle, var_T_reshape: T.handle):
        # Drop the leading unit axis: (1, seq_len, 20, 128) -> (seq_len, 20, 128), float16.
        # Output element [s, h, d] is read from input [0, s, h, d].
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape288 = T.match_buffer(var_reshape288, (T.int64(1), seq_len, T.int64(20), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(20), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(seq_len, T.int64(20), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Flat-index form: v_ax2 < 128 and v_ax1 < 20, so the indices reduce to
                # [0, v_ax0, v_ax1, v_ax2].
                T.reads(reshape288[T.int64(0), ((v_ax2 // T.int64(128) + v_ax1) // T.int64(20) + v_ax0) % seq_len, (v_ax2 // T.int64(128) + v_ax1) % T.int64(20), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape288[T.int64(0), ((v_ax2 // T.int64(128) + v_ax1) // T.int64(20) + v_ax0) % seq_len, (v_ax2 // T.int64(128) + v_ax1) % T.int64(20), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape6(var_lv365: T.handle, var_T_reshape: T.handle):
        # Add a leading unit axis: (seq_len, 16, 128) -> (1, seq_len, 16, 128), float16.
        # Output element [0, s, h, d] is read from input [s, h, d].
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv365 = T.match_buffer(var_lv365, (seq_len, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(16), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), seq_len, T.int64(16), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # Flat-index form: v_ax3 < 128, v_ax2 < 16 and v_ax0 == 0, so the indices
                # reduce to [v_ax1, v_ax2, v_ax3].
                T.reads(lv365[((v_ax3 // T.int64(128) + v_ax2) // T.int64(16) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv365[((v_ax3 // T.int64(128) + v_ax2) // T.int64(16) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape7(var_reshape290: T.handle, var_T_reshape: T.handle):
        # Flatten the last two axes: (1, seq_len, 16, 128) -> (1, seq_len, 2048), float16.
        # Output element [0, s, k] is read from input [0, s, k // 128, k % 128].
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape290 = T.match_buffer(var_reshape290, (T.int64(1), seq_len, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Flat-index form: v_ax2 < 2048 and v_ax0 == 0, so the sequence index
                # reduces to v_ax1.
                T.reads(reshape290[T.int64(0), (v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape290[T.int64(0), (v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def rms_norm(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight4: T.Buffer((T.int64(2048),), "float16"), var_T_cast: T.handle):
        # RMS normalization over the last (2048-wide) axis for a (batch_size, 1, 2048)
        # float16 input, computed in float32 and cast back to float16:
        #   out = fp16(rsqrt(sum(x^2) * (1/2048) + eps) * fp32(x) * fp32(weight))
        # where 0.00048828125 == 1/2048 (turns the sum into a mean) and eps is the
        # literal 9.9999999999999995e-07 (~1e-6).
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (batch_size, T.int64(1), T.int64(2048)), "float16")
        T_cast = T.match_buffer(var_T_cast, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        # Intermediate buffers; they receive float32 values (see the T.Cast("float32", ...) stages).
        T_cast_1 = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)))
        T_multiply = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)))
        T_multiply_red = T.alloc_buffer((batch_size, T.int64(1)))
        rsqrt = T.alloc_buffer((batch_size, T.int64(1)))
        T_cast_2 = T.alloc_buffer((T.int64(2048),))
        T_rms_norm = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)))
        # Stage 1: upcast the input to float32.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embeds[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2])
        # Stage 2: elementwise square.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # Stage 3: sum of squares over the last axis (reduction axis "R").
        for ax0, ax1, k2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # Stage 4: rsqrt(mean of squares + eps); 0.00048828125 == 1/2048.
        for ax0, ax1 in T.grid(batch_size, T.int64(1)):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07))
        # Stage 5: upcast the per-feature weight to float32.
        for ax0 in range(T.int64(2048)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(model_layers_0_input_layernorm_weight4[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", model_layers_0_input_layernorm_weight4[v_ax0])
        # Stage 6: scale the input by the row's rsqrt factor and the weight.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # Stage 7: downcast the result to float16 into the output buffer.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func(private=True)
    def rms_norm1(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight3: T.Buffer((T.int64(2048),), "float16"), var_T_cast: T.handle):
        # RMS normalization over the last (2048-wide) axis for a (1, seq_len, 2048)
        # float16 input, computed in float32 and cast back to float16:
        #   out = fp16(rsqrt(sum(x^2) * (1/2048) + eps) * fp32(x) * fp32(weight))
        # where 0.00048828125 == 1/2048 and eps is the literal 9.9999999999999995e-07 (~1e-6).
        # Same computation as rms_norm, but with the symbolic extent on axis 1 instead of axis 0.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (T.int64(1), seq_len, T.int64(2048)), "float16")
        T_cast = T.match_buffer(var_T_cast, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        # Intermediate buffers; they receive float32 values (see the T.Cast("float32", ...) stages).
        T_cast_1 = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)))
        T_multiply = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)))
        T_multiply_red = T.alloc_buffer((T.int64(1), seq_len))
        rsqrt = T.alloc_buffer((T.int64(1), seq_len))
        T_cast_2 = T.alloc_buffer((T.int64(2048),))
        T_rms_norm = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)))
        # Stage 1: upcast the input to float32.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embeds[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2])
        # Stage 2: elementwise square.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # Stage 3: sum of squares over the last axis (reduction axis "R").
        for ax0, ax1, k2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # Stage 4: rsqrt(mean of squares + eps); 0.00048828125 == 1/2048.
        for ax0, ax1 in T.grid(T.int64(1), seq_len):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07))
        # Stage 5: upcast the per-feature weight to float32.
        for ax0 in range(T.int64(2048)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(model_layers_0_input_layernorm_weight3[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", model_layers_0_input_layernorm_weight3[v_ax0])
        # Stage 6: scale the input by the row's rsqrt factor and the weight.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # Stage 7: downcast the result to float16 into the output buffer.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func(private=True)
    def rms_norm2(input_embed: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), model_layers_0_input_layernorm_weight2: T.Buffer((T.int64(2048),), "float16"), T_cast: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # RMS normalization of a single static-shape (1, 1, 2048) float16 vector,
        # computed in float32 and cast back to float16:
        #   out = fp16(rsqrt(sum(x^2) * (1/2048) + eps) * fp32(x) * fp32(weight))
        # where 0.00048828125 == 1/2048 and eps is the literal 9.9999999999999995e-07 (~1e-6).
        # Buffers are bound directly in the signature (no symbolic shape variables).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Intermediate buffers; they receive float32 values (see the T.Cast("float32", ...) stages).
        T_cast_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)))
        T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)))
        T_multiply_red = T.alloc_buffer((T.int64(1), T.int64(1)))
        rsqrt = T.alloc_buffer((T.int64(1), T.int64(1)))
        T_cast_2 = T.alloc_buffer((T.int64(2048),))
        T_rms_norm = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)))
        # Stage 1: upcast the input to float32.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embed[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embed[v_ax0, v_ax1, v_ax2])
        # Stage 2: elementwise square.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # Stage 3: sum of squares over the last axis (reduction axis "R").
        for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # Stage 4: rsqrt(mean of squares + eps); 0.00048828125 == 1/2048.
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07))
        # Stage 5: upcast the per-feature weight to float32.
        for ax0 in range(T.int64(2048)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(model_layers_0_input_layernorm_weight2[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", model_layers_0_input_layernorm_weight2[v_ax0])
        # Stage 6: scale the input by the rsqrt factor and the weight.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # Stage 7: downcast the result to float16 into the output buffer.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func
    def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, var_top_prob_probs: T.handle, var_top_prob_indices: T.handle):
        # Gather probabilities for two kinds of requests in one pass over
        # num_positions + num_samples iterations:
        #   * i < num_positions: top_prob_offsets[i] is a flat offset into the
        #     (batch_size, vocab_size) sorted_indices buffer. Write the token id found
        #     there to top_prob_indices[i], and that token's probability from
        #     unsorted_probs (same row) to top_prob_probs[i].
        #   * otherwise (j = i - num_positions): write
        #     unsorted_probs[sample_indices[j], sampling_results[j]] to sampled_values[j].
        # No index here is bounds-checked against the buffer shapes.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size))
        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32")
        num_samples = T.int32(is_size_var=True)
        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32")
        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32")
        num_positions = T.int32(is_size_var=True)
        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32")
        sampled_values = T.match_buffer(var_sampled_values, (num_samples,))
        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,))
        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
        # with T.block("root"):
        for i in range(num_positions + num_samples):
            with T.block("block"):
                vi = T.axis.spatial(num_positions + num_samples, i)
                # NOTE(review): the read/write regions below are the union over both
                # branches, so some listed indices (e.g. sample_indices[vi - num_positions]
                # when vi < num_positions) are out of range for the branch actually taken.
                # They are region annotations only; the branch bodies access just their own
                # buffers. Presumably harmless — confirm if a pass relies on exact regions.
                T.reads(top_prob_offsets[vi], sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], unsorted_probs[T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]):T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + (T.max(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + 1 - T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions])), T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]):T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + (T.max(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]))], sample_indices[vi - num_positions], sampling_results[vi - num_positions])
                T.writes(top_prob_indices[vi], top_prob_probs[vi], sampled_values[vi - num_positions])
                if vi < num_positions:
                    # Decompose the flat offset into (row, column) of sorted_indices.
                    row: T.int32 = top_prob_offsets[vi] // vocab_size
                    col: T.int32 = top_prob_offsets[vi] % vocab_size
                    top_prob_indices[vi] = sorted_indices[row, col]
                    # Probability of that token, looked up in the unsorted buffer by token id.
                    top_prob_probs[vi] = unsorted_probs[row, sorted_indices[row, col]]
                else:
                    vj: T.int32 = vi - num_positions
                    sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]

    @T.prim_func
    def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row scatter: for every b in [0, batch_size), copy src[b, :] into
        # dst[indices[b], :]. dst has m rows; rows not named in `indices` are left
        # untouched. indices[b] is not checked against m here.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n))
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        m = T.int32(is_size_var=True)
        dst = T.match_buffer(var_dst, (m, n))
        # with T.block("root"):
        for b, j in T.grid(batch_size, n):
            with T.block("scatter_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                T.reads(src[vb, vj], indices[vb])
                T.writes(dst[indices[vb], vj])
                dst[indices[vb], vj] = src[vb, vj]

    @T.prim_func
    def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle):
        # Finish a chunked softmax over A (batch_size, vocab_size) using per-chunk
        # statistics chunked_sum / chunked_max (batch_size, num_chunks).
        # One CUDA block is launched per (row, chunk) pair; row = v0, chunk = v1.
        # Each block:
        #   1. "max":     reduces chunked_max[row, :] into temp_max_shared[row].
        #   2. "sum_exp": combines chunked_sum[row, :] into temp_sum_shared[row].
        #   3. "log_pad": writes softmax[row, chunk * 4096 : chunk * 4096 + 4096]
        #                 (clipped at vocab_size).
        # Steps 1 and 2 are recomputed over all chunks by every block of the same row.
        # NOTE(review): the 4096 chunk width is hard-coded in step 3, so this presumably
        # requires num_chunks == ceil(vocab_size / 4096); that is not checked here.
        # NOTE(review): from the formulas below, chunked_sum appears to hold
        # log(sum(exp(A / T - chunked_max))) per chunk when temperature > 1e-5, and a
        # plain count/sum in the low-temperature case — confirm against the producer kernel.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size))
        # with T.block("root"):
        temp_max_shared = T.alloc_buffer((batch_size,), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared")
        for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Step 1: 32 threads (threadIdx.x) stride over the chunks to find the row max.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("max"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(chunked_max[v0, v1])
                        T.writes(temp_max_shared[v0])
                        with T.init():
                            # Lowest finite float32 (-FLT_MAX) as the identity for max.
                            temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0)
                        temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1])
            # Step 2: combine the per-chunk sums relative to the row max.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("sum_exp"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0])
                        T.writes(temp_sum_shared[v0])
                        with T.init():
                            temp_sum_shared[v0] = T.float32(0.0)
                        # temperature > 1e-5: add exp(chunked_sum + chunked_max - row_max).
                        # Otherwise: add chunked_sum only for chunks whose max equals the row max.
                        temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1])
            # Step 3: 4 x 32 x 32 = 4096 elements per block, offset by chunk * 4096.
            for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("log_pad"):
                            v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                            v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks)
                            v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0])
                            T.writes(softmax[v0, v1 * T.int64(4096) + v2])
                            # Guard the tail of the last chunk, which may extend past vocab_size.
                            if v1 * T.int64(4096) + v2 < vocab_size:
                                # temperature > 1e-5: exp(A / T - (log(sum) + max)).
                                # Otherwise: (A == max) / sum, i.e. mass split evenly over entries equal to the max.
                                softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0])

    @T.prim_func(private=True)
    def take(var_rms_norm218: T.handle, var_logit_positions: T.handle, var_T_take: T.handle):
        # Gather rows along the sequence axis:
        #   T_take[0, i, :] = rms_norm218[0, logit_positions[i], :]   for i in [0, batch_size)
        # Input is (1, seq_len, 2048) float16; output is (1, batch_size, 2048) float16.
        # logit_positions[i] is not checked against seq_len here.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm218 = T.match_buffer(var_rms_norm218, (T.int64(1), seq_len, T.int64(2048)), "float16")
        batch_size = T.int64()
        logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32")
        T_take = T.match_buffer(var_T_take, (T.int64(1), batch_size, T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), batch_size, T.int64(2048)):
            with T.block("T_take"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rms_norm218[v_ax0, logit_positions[v_ax1], v_ax2], logit_positions[v_ax1])
                T.writes(T_take[v_ax0, v_ax1, v_ax2])
                T_take[v_ax0, v_ax1, v_ax2] = rms_norm218[v_ax0, logit_positions[v_ax1], v_ax2]

    @T.prim_func(private=True)
    def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle):
        # Per-row gather: take_sorted_probs[i, j] = probs[i, lv1[i, j]].
        # If lv1 holds each row's indices in sorted order (presumably descending
        # probability — confirm at the call site), the output is each row's
        # probabilities in that order.
        # NOTE(review): the loop extents come from the output's own shape variables
        # (batch_size_1, vocab_size_1), which nothing here ties to the inputs'
        # batch_size / vocab_size. Callers are presumably expected to pass matching shapes.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size))
        lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32")
        batch_size_1, vocab_size_1 = T.int64(is_size_var=True), T.int64(is_size_var=True)
        take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1))
        # with T.block("root"):
        for i, j in T.grid(batch_size_1, vocab_size_1):
            with T.block("take_sorted_probs"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(probs[v_i, lv1[v_i, v_j]], lv1[v_i, v_j])
                T.writes(take_sorted_probs[v_i, v_j])
                take_sorted_probs[v_i, v_j] = probs[v_i, lv1[v_i, v_j]]

    @T.prim_func
    def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64):
        # Copy K and V values for `seqlen` positions out of the paged cache into dense
        # per-layer buffers at index `layer_id`:
        #   page   = position_map[p] // page_size
        #   offset = position_map[p] %  page_size
        #   k_data[layer_id, p, h, d] = pages[page, 0, h, offset, d]
        #   v_data[layer_id, p, h, d] = pages[page, 1, h, offset, d]
        # pages is (num_pages, 2, 2, page_size, 128) float16; axis 1 selects K (0) / V (1)
        # and axis 2 is the head index (2 heads).
        # NOTE(review): the leading extent 36 of k_data / v_data is presumably the
        # number of layers; layer_id is not range-checked here. Unlike the append
        # kernel, position_map entries are not filtered for -1.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        num_pages, page_size = T.int64(), T.int64(is_size_var=True)
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, page_size, 128), "float16", offset_factor=1)
        seqlen = T.int64(is_size_var=True)
        position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1)
        k_data = T.match_buffer(var_k_data, (36, seqlen, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (36, seqlen, 2, 128), "float16")
        # with T.block("root"):
        for p, h, d in T.grid(seqlen, 2, 128):
            with T.block("copy0"):
                vp, vh, vd = T.axis.remap("SSS", [p, h, d])
                T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd])
                T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                position: T.int32 = position_map[vp]
                # Widen the int32 position to int64 to match the int64 page_size.
                k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd]
                v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd]

    @T.prim_func
    def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle):
        # Scatter new K/V values for `ntoken` tokens into the paged cache.
        # pages is (num_pages, 2, 2, 16, 128) float16 (page size fixed at 16 here).
        # For every token t with position_map[t] != -1:
        #   pages[pos // 16, 0, h, pos % 16, f] = k_data[t, h, f]
        #   pages[pos // 16, 1, h, pos % 16, f] = v_data[t, h, f]
        # Tokens whose position is -1 are skipped, so no write happens for them.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        num_pages = T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 2, 16, 128), "float16", offset_factor=1)
        ntoken = T.int64(is_size_var=True)
        k_data = T.match_buffer(var_k_data, (ntoken, 2, 128), "float16")
        v_data = T.match_buffer(var_v_data, (ntoken, 2, 128), "float16")
        position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1)
        # with T.block("root"):
        for global_pos, h, f in T.grid(ntoken, 2, 128):
            # -1 marks a token that must not be written into the cache.
            if position_map[global_pos] != -1:
                # K goes to slot 0 of axis 1 of the page.
                with T.block("k_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf]
                # V goes to slot 1 of axis 1 of the page.
                with T.block("v_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf]

    # Per-row search for the top-p cutoff probability ("pivot").
    #
    # Launch: one thread block of 1024 threads per row b of `prob` (B, N).
    # The search keeps an interval of candidate pivot values in shared
    # L (upper end, initialised to 1 - top_p) and R_1 (lower end, initialised
    # to 1e-7).  Each round evaluates three candidate pivots: init_pivots[b, :]
    # in the first round, then L - k * (L - R) / 4 for k = 1, 2, 3.
    # For every candidate it computes over the whole row:
    #   lsum = sum of probs >= pivot
    #   lmin = smallest prob >= pivot
    #   cmin = number of probs equal to lmin
    # A candidate is accepted when lsum >= top_p but lsum - cmin * lmin < top_p,
    # i.e. the kept set would drop below top_p if its smallest ties were removed.
    # Outputs (written by thread 0 only): final_pivot[b], the cutoff value, and
    # final_lsum[b], the probability mass kept at that cutoff.
    @T.prim_func(private=True)
    def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        top_p_arr = T.match_buffer(var_top_p_arr, (B,))
        init_pivots = T.match_buffer(var_init_pivots, (B, 3))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        # with T.block("root"):
        # Scratch state.
        # - "shared" buffers are block-wide and are written by thread 0 before a
        #   barrier.
        # - The *_local copies hold each thread's view after that barrier.
        # - *_reduce buffers receive the results of the block-wide allreduces.
        pivot = T.alloc_buffer((3,), scope="local")
        top_p = T.alloc_buffer((1,), scope="local")
        L = T.alloc_buffer((1,), scope="shared")
        R_1 = T.alloc_buffer((1,), scope="shared")
        L_local = T.alloc_buffer((1,), scope="local")
        R_local = T.alloc_buffer((1,), scope="local")
        q = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((3,), scope="local")
        lmin_broadcast = T.alloc_buffer((1,), scope="shared")
        lmin_broadcast_local = T.alloc_buffer((1,), scope="local")
        lmin = T.alloc_buffer((3,), scope="local")
        cmin = T.alloc_buffer((3,), "int32", scope="local")
        total_sum = T.alloc_buffer((1,), scope="local")
        it = T.alloc_buffer((1,), "int32", scope="local")
        es_local = T.alloc_buffer((1,), "bool", scope="local")
        es = T.alloc_buffer((1,), "bool", scope="shared")
        find_pivot_local = T.alloc_buffer((1,), "bool", scope="local")
        find_pivot = T.alloc_buffer((1,), "bool", scope="shared")
        total_sum_reduce = T.alloc_buffer((1,), scope="local")
        lsum_reduce = T.alloc_buffer((1,), scope="local")
        lmin_reduce = T.alloc_buffer((1,), scope="local")
        cmin_reduce = T.alloc_buffer((1,), "int32", scope="local")
        for _bx in T.thread_binding(B, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0])
                    T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0])
                    top_p[0] = top_p_arr[b]
                    # Thread 0 initialises the shared search interval: L = 1 - top_p (upper), R = 1e-7 (lower).
                    if tx == 0:
                        L[0] = T.float32(1.0) - top_p[0]
                        R_1[0] = T.float32(9.9999999999999995e-08)
                        find_pivot[0] = T.bool(False)
                    T.tvm_storage_sync("shared")
                    L_local[0] = L[0]
                    R_local[0] = R_1[0]
                    for i in T.unroll(3):
                        pivot[i] = init_pivots[b, i]
                    find_pivot_local[0] = T.bool(False)
                    # Interval already narrower than the tolerance (top_p within about 2e-7 of 1).
                    # Emit pivot 0 and lsum 1 (keep everything) and skip the search loop.
                    if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08):
                        if tx == 0:
                            final_lsum[b] = T.float32(1.0)
                            final_pivot[b] = T.float32(0.0)
                        find_pivot_local[0] = T.bool(True)
                    # One search round per iteration.
                    # The condition only reads values broadcast through shared memory, so all threads agree on it.
                    while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]):
                        T.tvm_storage_sync("shared")
                        # Reset per-thread accumulators for the three candidate pivots.
                        for pidx in T.unroll(3):
                            lsum[pidx] = T.float32(0.0)
                            lmin[pidx] = T.float32(340282346638528859811704183484516925440.0)
                            cmin[pidx] = 0
                        total_sum[0] = T.float32(0.0)
                        it[0] = 0
                        es_local[0] = T.bool(False)
                        # Stream the row in chunks of 1024 elements (one per thread per chunk).
                        # Out-of-range columns contribute 0.
                        while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]:
                            q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0.0))
                            total_sum[0] = total_sum[0] + q[0]
                            for pidx in T.unroll(3):
                                if q[0] >= pivot[pidx]:
                                    lsum[pidx] = lsum[pidx] + q[0]
                                    # Track this thread's smallest kept value and how many times it occurs.
                                    if lmin[pidx] > q[0]:
                                        lmin[pidx] = q[0]
                                        cmin[pidx] = 1
                                    else:
                                        if lmin[pidx] == q[0]:
                                            cmin[pidx] = cmin[pidx] + 1
                            it[0] = it[0] + 1
                            # Every 32 chunks, check for early stop.
                            # Stop when 1 - (mass seen so far) < pivot[2].
                            # NOTE(review): this presumably relies on each row of `prob` summing to 1.
                            # It also presumes pivot[2] is the smallest candidate.
                            # Under those assumptions, no unscanned element can reach any pivot — confirm.
                            if it[0] % 32 == 0:
                                with T.block("block_cross_thread"):
                                    T.reads(total_sum[0])
                                    T.writes(total_sum_reduce[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx)
                                if tx == 0:
                                    es[0] = T.float32(1.0) - total_sum_reduce[0] < pivot[2]
                                T.tvm_storage_sync("shared")
                                es_local[0] = es[0]
                        T.tvm_storage_sync("shared")
                        # Block-wide reductions per candidate: sum of lsum, min of lmin, and the count of
                        # ties at the global minimum. Afterwards only thread 0 holds the reduced values.
                        for pidx in range(3):
                            with T.block("block_cross_thread"):
                                T.reads(lsum[pidx])
                                T.writes(lsum_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            # NOTE(review): the min-reducer's identity element is 0.0 rather than +FLT_MAX.
                            # This is harmless only if no padding lanes take part in this allreduce — confirm.
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            # Broadcast the global minimum through shared memory to every thread.
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            # A thread's tie count only counts if its local minimum is the global minimum.
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        # Thread 0 classifies each candidate and narrows the shared interval.
                        if tx == 0:
                            it[0] = 0
                            while it[0] < 3 and not find_pivot_local[0]:
                                # Accept: enough mass is kept, but dropping the min ties would fall below top_p.
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    # Still >= top_p after dropping the ties: the pivot can rise, so it becomes the new
                                    # lower end R. Remember its kept mass.
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        # Not enough mass kept: the pivot must fall, so it becomes the new upper end L.
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        # Next candidates sit at the interior quarter points of the new interval.
                        # L[0] equals L_local[0] here, since both were read after the barrier above.
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4.0)
                    # No candidate was accepted before the interval shrank below the tolerance.
                    # Fall back to the lower end R.
                    # In the loop, final_lsum is only written on acceptance or when R moves.
                    # If R still equals its initial 1e-7, use thread 0's reduced lsum for pivot[2] from the last round.
                    if tx == 0:
                        if not find_pivot_local[0]:
                            final_pivot[b] = R_local[0]
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    # Apply a per-row probability cutoff and renormalise.
    # For each row `by` and column j:
    #   renorm_prob[by, j] = prob[by, j] / final_lsum[by]  if prob[by, j] >= final_pivot[by]
    #   renorm_prob[by, j] = 0                             otherwise
    # NOTE(review): final_pivot / final_lsum presumably come from
    # top_p_pivot_cutoff — confirm the wiring in the caller.
    #
    # Launch geometry:
    # - blockIdx.y is the row.
    # - blockIdx.x has ceil(512 / B) blocks. Note that 511 // B + 1 == (512 + B - 1) // B,
    #   so there are at least 512 blocks in total whatever B is.
    # - Each block handles 1024 consecutive columns per step and strides by
    #   ceil(512 / B) * 1024 columns.
    # - The step count equals ceil(N / (ceil(512 / B) * 1024)).
    @T.prim_func(private=True)
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        # Per-thread copies of this row's cutoff and kept mass.
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding(511 // B + 1, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        T.reads(final_pivot[by], final_lsum[by], prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)])
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        # Strided sweep over the row; the bounds check masks the tail past column N.
                        for i in range((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024)):
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0.0))

    @T.prim_func
    def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 2, 16, 128), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1)
        total_tree_order_len = T.int32(is_size_var=True)
        tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1)
        # with T.block("root"):
        assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention."
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(2, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * 8
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * 8
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i) // 8
                                                        cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = (LH_start + row) // 8
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min((LH_start + row) // 8 + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = (LH_start + row) // 8
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i) // 8, by * 8 + (LH_start + i) % 8])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // 8
                                                    cur_H_qo: T.int32 = by * 8 + (LH_start + i) % 8
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    # Relax function: allocates and returns a (2048, 2048) float16 tensor via
    # R.builtin.alloc_tensor. No values are written to it here.
    # NOTE(review): this module is printed TVMScript; every binding below is IR,
    # so edit only comments unless regenerating from the compiler.
    # Presumably an embedding buffer sized to the hidden dimension (2048 matches
    # the last dim of `input_embeds` in batch_decode) -- TODO confirm the meaning
    # of the leading 2048 (possibly the seq_len upper bound of 2048).
    @R.function
    def alloc_embedding_tensor() -> R.Tensor((2048, 2048), dtype="float16"):
        # Attribute consumed by Relax memory planning; presumably lets the
        # planner manage this function's output allocation -- verify in TVM.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True})
        # Arguments: shape, dtype, R.prim_value(0) (assumed runtime device index
        # -- confirm against the alloc_tensor signature), storage scope "global".
        gv: R.Tensor((2048, 2048), dtype="float16") = R.builtin.alloc_tensor(R.shape([2048, 2048]), R.dtype("float16"), R.prim_value(0), R.str("global"))
        return gv

    # Relax function: sorts each row of `probs` (shape (batch_size, vocab_size),
    # float32) and returns (sorted_probs: float32, sorted_indices: int32), both
    # of shape (batch_size, vocab_size).
    # The indices come from the TIR kernel `cls.argsort`; the values are then
    # gathered by `cls.take_sorted_probs` using those indices. Sort direction
    # is decided by `argsort`, which is not visible here -- TODO confirm
    # ascending vs. descending.
    # NOTE(review): printed TVMScript; binding names and order are part of the IR.
    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        # Symbolic shape variables bound from the `probs` signature.
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # Attributes:
        # - vocab_size is declared non-negative;
        # - batch_size is bounded above by 80.
        # `num_positions` and `num_samples` are not used by this function.
        # They presumably come from a shared attribute set emitted for all
        # sampling functions.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Step 1: per-row sort permutation (int32 indices).
            lv1 = R.call_tir(cls.argsort, (probs,), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="int32"))
            # Step 2: gather probabilities in the order given by lv1.
            lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            # Output tuple order is (values, indices).
            gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1
            R.output(gv1)
        return gv1

    @R.function
    def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight4: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale4: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight4: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale4: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias4: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight4: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale4: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight4: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale4: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm219 = R.call_tir(cls.rms_norm, (input_embeds, model_layers_0_input_layernorm_weight4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_0_self_attn_c_attn_q_weight4, model_layers_0_self_attn_c_attn_q_scale4, rms_norm219, model_layers_0_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape432 = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape433 = R.call_tir(cls.reshape1, (reshape432,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv546 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape433), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape434 = R.call_tir(cls.reshape2, (lv546,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape435 = R.call_tir(cls.reshape3, (reshape434,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_0_self_attn_o_proj_q_weight4, model_layers_0_self_attn_o_proj_q_scale4, reshape435), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv_2 = R.call_tir(cls.fuse_add_norm_decode, (lv_1, input_embeds, model_layers_0_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_2[1]
            rms_norm220: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_2[0]
            lv1_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_0_mlp_gate_up_proj_q_weight4, model_layers_0_mlp_gate_up_proj_q_scale4, rms_norm220), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv1_2 = R.call_tir(cls.fused_split_silu_multiply, (lv1_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv2 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_0_mlp_down_proj_q_weight4, model_layers_0_mlp_down_proj_q_scale4, lv1_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv2_1 = R.call_tir(cls.fuse_add_norm_decode, (lv2, lv1, model_layers_1_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv3: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[1]
            rms_norm221: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[0]
            lv1_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_1_self_attn_c_attn_q_weight4, model_layers_1_self_attn_c_attn_q_scale4, rms_norm221, model_layers_1_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape436 = R.call_tir(cls.reshape, (lv1_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape437 = R.call_tir(cls.reshape1, (reshape436,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv551 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape437), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape438 = R.call_tir(cls.reshape2, (lv551,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape439 = R.call_tir(cls.reshape3, (reshape438,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv3_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_1_self_attn_o_proj_q_weight4, model_layers_1_self_attn_o_proj_q_scale4, reshape439), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv4 = R.call_tir(cls.fuse_add_norm_decode, (lv3_1, lv3, model_layers_1_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv5: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4[1]
            rms_norm222: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4[0]
            lv4_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_1_mlp_gate_up_proj_q_weight4, model_layers_1_mlp_gate_up_proj_q_scale4, rms_norm222), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv3_2 = R.call_tir(cls.fused_split_silu_multiply, (lv4_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv5_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_1_mlp_down_proj_q_weight4, model_layers_1_mlp_down_proj_q_scale4, lv3_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv6 = R.call_tir(cls.fuse_add_norm_decode, (lv5_1, lv5, model_layers_2_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv7: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6[1]
            rms_norm223: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6[0]
            lv2_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_2_self_attn_c_attn_q_weight4, model_layers_2_self_attn_c_attn_q_scale4, rms_norm223, model_layers_2_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape440 = R.call_tir(cls.reshape, (lv2_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape441 = R.call_tir(cls.reshape1, (reshape440,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv556 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape441), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape442 = R.call_tir(cls.reshape2, (lv556,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape443 = R.call_tir(cls.reshape3, (reshape442,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv6_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_2_self_attn_o_proj_q_weight4, model_layers_2_self_attn_o_proj_q_scale4, reshape443), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv8 = R.call_tir(cls.fuse_add_norm_decode, (lv6_1, lv7, model_layers_2_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv9: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8[1]
            rms_norm224: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8[0]
            lv7_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_2_mlp_gate_up_proj_q_weight4, model_layers_2_mlp_gate_up_proj_q_scale4, rms_norm224), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv5_2 = R.call_tir(cls.fused_split_silu_multiply, (lv7_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv8_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_2_mlp_down_proj_q_weight4, model_layers_2_mlp_down_proj_q_scale4, lv5_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv10 = R.call_tir(cls.fuse_add_norm_decode, (lv8_1, lv9, model_layers_3_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv11: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10[1]
            rms_norm225: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10[0]
            lv3_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_3_self_attn_c_attn_q_weight4, model_layers_3_self_attn_c_attn_q_scale4, rms_norm225, model_layers_3_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape444 = R.call_tir(cls.reshape, (lv3_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape445 = R.call_tir(cls.reshape1, (reshape444,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv561 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape445), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape446 = R.call_tir(cls.reshape2, (lv561,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape447 = R.call_tir(cls.reshape3, (reshape446,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv9_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_3_self_attn_o_proj_q_weight4, model_layers_3_self_attn_o_proj_q_scale4, reshape447), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv12 = R.call_tir(cls.fuse_add_norm_decode, (lv9_1, lv11, model_layers_3_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv13: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12[1]
            rms_norm226: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12[0]
            lv10_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_3_mlp_gate_up_proj_q_weight4, model_layers_3_mlp_gate_up_proj_q_scale4, rms_norm226), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv7_2 = R.call_tir(cls.fused_split_silu_multiply, (lv10_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv11_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_3_mlp_down_proj_q_weight4, model_layers_3_mlp_down_proj_q_scale4, lv7_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv14 = R.call_tir(cls.fuse_add_norm_decode, (lv11_1, lv13, model_layers_4_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv15: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14[1]
            rms_norm227: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14[0]
            lv4_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_4_self_attn_c_attn_q_weight4, model_layers_4_self_attn_c_attn_q_scale4, rms_norm227, model_layers_4_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape448 = R.call_tir(cls.reshape, (lv4_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape449 = R.call_tir(cls.reshape1, (reshape448,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv566 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape449), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape450 = R.call_tir(cls.reshape2, (lv566,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape451 = R.call_tir(cls.reshape3, (reshape450,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv12_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_4_self_attn_o_proj_q_weight4, model_layers_4_self_attn_o_proj_q_scale4, reshape451), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv16 = R.call_tir(cls.fuse_add_norm_decode, (lv12_1, lv15, model_layers_4_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv17: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16[1]
            rms_norm228: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16[0]
            lv13_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_4_mlp_gate_up_proj_q_weight4, model_layers_4_mlp_gate_up_proj_q_scale4, rms_norm228), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv9_2 = R.call_tir(cls.fused_split_silu_multiply, (lv13_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv14_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_4_mlp_down_proj_q_weight4, model_layers_4_mlp_down_proj_q_scale4, lv9_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv18 = R.call_tir(cls.fuse_add_norm_decode, (lv14_1, lv17, model_layers_5_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv19: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18[1]
            rms_norm229: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18[0]
            lv5_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_5_self_attn_c_attn_q_weight4, model_layers_5_self_attn_c_attn_q_scale4, rms_norm229, model_layers_5_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape452 = R.call_tir(cls.reshape, (lv5_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape453 = R.call_tir(cls.reshape1, (reshape452,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv571 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape453), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape454 = R.call_tir(cls.reshape2, (lv571,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape455 = R.call_tir(cls.reshape3, (reshape454,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv15_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_5_self_attn_o_proj_q_weight4, model_layers_5_self_attn_o_proj_q_scale4, reshape455), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv20 = R.call_tir(cls.fuse_add_norm_decode, (lv15_1, lv19, model_layers_5_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv21: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20[1]
            rms_norm230: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20[0]
            lv16_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_5_mlp_gate_up_proj_q_weight4, model_layers_5_mlp_gate_up_proj_q_scale4, rms_norm230), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv11_2 = R.call_tir(cls.fused_split_silu_multiply, (lv16_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv17_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_5_mlp_down_proj_q_weight4, model_layers_5_mlp_down_proj_q_scale4, lv11_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv22 = R.call_tir(cls.fuse_add_norm_decode, (lv17_1, lv21, model_layers_6_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv23: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22[1]
            rms_norm231: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22[0]
            lv6_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_6_self_attn_c_attn_q_weight4, model_layers_6_self_attn_c_attn_q_scale4, rms_norm231, model_layers_6_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape456 = R.call_tir(cls.reshape, (lv6_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape457 = R.call_tir(cls.reshape1, (reshape456,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv576 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape457), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape458 = R.call_tir(cls.reshape2, (lv576,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape459 = R.call_tir(cls.reshape3, (reshape458,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv18_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_6_self_attn_o_proj_q_weight4, model_layers_6_self_attn_o_proj_q_scale4, reshape459), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv24 = R.call_tir(cls.fuse_add_norm_decode, (lv18_1, lv23, model_layers_6_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv25: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24[1]
            rms_norm232: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24[0]
            lv19_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_6_mlp_gate_up_proj_q_weight4, model_layers_6_mlp_gate_up_proj_q_scale4, rms_norm232), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv13_2 = R.call_tir(cls.fused_split_silu_multiply, (lv19_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv20_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_6_mlp_down_proj_q_weight4, model_layers_6_mlp_down_proj_q_scale4, lv13_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv26 = R.call_tir(cls.fuse_add_norm_decode, (lv20_1, lv25, model_layers_7_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv27: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26[1]
            rms_norm233: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26[0]
            lv7_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_7_self_attn_c_attn_q_weight4, model_layers_7_self_attn_c_attn_q_scale4, rms_norm233, model_layers_7_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape460 = R.call_tir(cls.reshape, (lv7_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape461 = R.call_tir(cls.reshape1, (reshape460,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv581 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape461), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape462 = R.call_tir(cls.reshape2, (lv581,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape463 = R.call_tir(cls.reshape3, (reshape462,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv21_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_7_self_attn_o_proj_q_weight4, model_layers_7_self_attn_o_proj_q_scale4, reshape463), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv28 = R.call_tir(cls.fuse_add_norm_decode, (lv21_1, lv27, model_layers_7_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv29: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28[1]
            rms_norm234: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28[0]
            lv22_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_7_mlp_gate_up_proj_q_weight4, model_layers_7_mlp_gate_up_proj_q_scale4, rms_norm234), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv15_2 = R.call_tir(cls.fused_split_silu_multiply, (lv22_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv23_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_7_mlp_down_proj_q_weight4, model_layers_7_mlp_down_proj_q_scale4, lv15_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv30 = R.call_tir(cls.fuse_add_norm_decode, (lv23_1, lv29, model_layers_8_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv31: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30[1]
            rms_norm235: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30[0]
            lv8_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_8_self_attn_c_attn_q_weight4, model_layers_8_self_attn_c_attn_q_scale4, rms_norm235, model_layers_8_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape464 = R.call_tir(cls.reshape, (lv8_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape465 = R.call_tir(cls.reshape1, (reshape464,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv586 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape465), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape466 = R.call_tir(cls.reshape2, (lv586,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape467 = R.call_tir(cls.reshape3, (reshape466,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv24_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_8_self_attn_o_proj_q_weight4, model_layers_8_self_attn_o_proj_q_scale4, reshape467), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv32 = R.call_tir(cls.fuse_add_norm_decode, (lv24_1, lv31, model_layers_8_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv33: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32[1]
            rms_norm236: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32[0]
            lv25_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_8_mlp_gate_up_proj_q_weight4, model_layers_8_mlp_gate_up_proj_q_scale4, rms_norm236), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv17_2 = R.call_tir(cls.fused_split_silu_multiply, (lv25_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv26_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_8_mlp_down_proj_q_weight4, model_layers_8_mlp_down_proj_q_scale4, lv17_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv34 = R.call_tir(cls.fuse_add_norm_decode, (lv26_1, lv33, model_layers_9_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv35: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34[1]
            rms_norm237: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34[0]
            lv9_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_9_self_attn_c_attn_q_weight4, model_layers_9_self_attn_c_attn_q_scale4, rms_norm237, model_layers_9_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape468 = R.call_tir(cls.reshape, (lv9_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape469 = R.call_tir(cls.reshape1, (reshape468,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv591 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape469), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape470 = R.call_tir(cls.reshape2, (lv591,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape471 = R.call_tir(cls.reshape3, (reshape470,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv27_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_9_self_attn_o_proj_q_weight4, model_layers_9_self_attn_o_proj_q_scale4, reshape471), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv36 = R.call_tir(cls.fuse_add_norm_decode, (lv27_1, lv35, model_layers_9_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv37: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36[1]
            rms_norm238: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36[0]
            lv28_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_9_mlp_gate_up_proj_q_weight4, model_layers_9_mlp_gate_up_proj_q_scale4, rms_norm238), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv19_2 = R.call_tir(cls.fused_split_silu_multiply, (lv28_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv29_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_9_mlp_down_proj_q_weight4, model_layers_9_mlp_down_proj_q_scale4, lv19_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv38 = R.call_tir(cls.fuse_add_norm_decode, (lv29_1, lv37, model_layers_10_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv39: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38[1]
            rms_norm239: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38[0]
            lv10_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_10_self_attn_c_attn_q_weight4, model_layers_10_self_attn_c_attn_q_scale4, rms_norm239, model_layers_10_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape472 = R.call_tir(cls.reshape, (lv10_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape473 = R.call_tir(cls.reshape1, (reshape472,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv596 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape473), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape474 = R.call_tir(cls.reshape2, (lv596,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape475 = R.call_tir(cls.reshape3, (reshape474,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv30_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_10_self_attn_o_proj_q_weight4, model_layers_10_self_attn_o_proj_q_scale4, reshape475), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv40 = R.call_tir(cls.fuse_add_norm_decode, (lv30_1, lv39, model_layers_10_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv41: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40[1]
            rms_norm240: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40[0]
            lv31_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_10_mlp_gate_up_proj_q_weight4, model_layers_10_mlp_gate_up_proj_q_scale4, rms_norm240), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv21_2 = R.call_tir(cls.fused_split_silu_multiply, (lv31_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv32_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_10_mlp_down_proj_q_weight4, model_layers_10_mlp_down_proj_q_scale4, lv21_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv42 = R.call_tir(cls.fuse_add_norm_decode, (lv32_1, lv41, model_layers_11_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv43: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42[1]
            rms_norm241: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42[0]
            lv11_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_11_self_attn_c_attn_q_weight4, model_layers_11_self_attn_c_attn_q_scale4, rms_norm241, model_layers_11_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape476 = R.call_tir(cls.reshape, (lv11_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape477 = R.call_tir(cls.reshape1, (reshape476,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv601 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape477), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape478 = R.call_tir(cls.reshape2, (lv601,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape479 = R.call_tir(cls.reshape3, (reshape478,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv33_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_11_self_attn_o_proj_q_weight4, model_layers_11_self_attn_o_proj_q_scale4, reshape479), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv44 = R.call_tir(cls.fuse_add_norm_decode, (lv33_1, lv43, model_layers_11_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv45: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44[1]
            rms_norm242: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44[0]
            lv34_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_11_mlp_gate_up_proj_q_weight4, model_layers_11_mlp_gate_up_proj_q_scale4, rms_norm242), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv23_2 = R.call_tir(cls.fused_split_silu_multiply, (lv34_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv35_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_11_mlp_down_proj_q_weight4, model_layers_11_mlp_down_proj_q_scale4, lv23_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv46 = R.call_tir(cls.fuse_add_norm_decode, (lv35_1, lv45, model_layers_12_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv47: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46[1]
            rms_norm243: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46[0]
            lv12_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_12_self_attn_c_attn_q_weight4, model_layers_12_self_attn_c_attn_q_scale4, rms_norm243, model_layers_12_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape480 = R.call_tir(cls.reshape, (lv12_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape481 = R.call_tir(cls.reshape1, (reshape480,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv606 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape481), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape482 = R.call_tir(cls.reshape2, (lv606,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape483 = R.call_tir(cls.reshape3, (reshape482,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv36_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_12_self_attn_o_proj_q_weight4, model_layers_12_self_attn_o_proj_q_scale4, reshape483), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv48 = R.call_tir(cls.fuse_add_norm_decode, (lv36_1, lv47, model_layers_12_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv49: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48[1]
            rms_norm244: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48[0]
            lv37_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_12_mlp_gate_up_proj_q_weight4, model_layers_12_mlp_gate_up_proj_q_scale4, rms_norm244), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv25_2 = R.call_tir(cls.fused_split_silu_multiply, (lv37_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv38_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_12_mlp_down_proj_q_weight4, model_layers_12_mlp_down_proj_q_scale4, lv25_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv50 = R.call_tir(cls.fuse_add_norm_decode, (lv38_1, lv49, model_layers_13_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv51: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50[1]
            rms_norm245: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50[0]
            lv13_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_13_self_attn_c_attn_q_weight4, model_layers_13_self_attn_c_attn_q_scale4, rms_norm245, model_layers_13_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape484 = R.call_tir(cls.reshape, (lv13_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape485 = R.call_tir(cls.reshape1, (reshape484,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv611 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape485), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape486 = R.call_tir(cls.reshape2, (lv611,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape487 = R.call_tir(cls.reshape3, (reshape486,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv39_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_13_self_attn_o_proj_q_weight4, model_layers_13_self_attn_o_proj_q_scale4, reshape487), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv52 = R.call_tir(cls.fuse_add_norm_decode, (lv39_1, lv51, model_layers_13_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv53: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52[1]
            rms_norm246: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52[0]
            lv40_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_13_mlp_gate_up_proj_q_weight4, model_layers_13_mlp_gate_up_proj_q_scale4, rms_norm246), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv27_2 = R.call_tir(cls.fused_split_silu_multiply, (lv40_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv41_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_13_mlp_down_proj_q_weight4, model_layers_13_mlp_down_proj_q_scale4, lv27_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv54 = R.call_tir(cls.fuse_add_norm_decode, (lv41_1, lv53, model_layers_14_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv55: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54[1]
            rms_norm247: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54[0]
            lv14_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_14_self_attn_c_attn_q_weight4, model_layers_14_self_attn_c_attn_q_scale4, rms_norm247, model_layers_14_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape488 = R.call_tir(cls.reshape, (lv14_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape489 = R.call_tir(cls.reshape1, (reshape488,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv616 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape489), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape490 = R.call_tir(cls.reshape2, (lv616,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape491 = R.call_tir(cls.reshape3, (reshape490,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv42_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_14_self_attn_o_proj_q_weight4, model_layers_14_self_attn_o_proj_q_scale4, reshape491), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv56 = R.call_tir(cls.fuse_add_norm_decode, (lv42_1, lv55, model_layers_14_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv57: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56[1]
            rms_norm248: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56[0]
            lv43_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_14_mlp_gate_up_proj_q_weight4, model_layers_14_mlp_gate_up_proj_q_scale4, rms_norm248), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv29_2 = R.call_tir(cls.fused_split_silu_multiply, (lv43_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv44_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_14_mlp_down_proj_q_weight4, model_layers_14_mlp_down_proj_q_scale4, lv29_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv58 = R.call_tir(cls.fuse_add_norm_decode, (lv44_1, lv57, model_layers_15_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv59: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58[1]
            rms_norm249: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58[0]
            lv15_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_15_self_attn_c_attn_q_weight4, model_layers_15_self_attn_c_attn_q_scale4, rms_norm249, model_layers_15_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape492 = R.call_tir(cls.reshape, (lv15_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape493 = R.call_tir(cls.reshape1, (reshape492,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv621 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape493), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape494 = R.call_tir(cls.reshape2, (lv621,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape495 = R.call_tir(cls.reshape3, (reshape494,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv45_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_15_self_attn_o_proj_q_weight4, model_layers_15_self_attn_o_proj_q_scale4, reshape495), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv60 = R.call_tir(cls.fuse_add_norm_decode, (lv45_1, lv59, model_layers_15_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv61: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60[1]
            rms_norm250: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60[0]
            lv46_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_15_mlp_gate_up_proj_q_weight4, model_layers_15_mlp_gate_up_proj_q_scale4, rms_norm250), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv31_2 = R.call_tir(cls.fused_split_silu_multiply, (lv46_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv47_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_15_mlp_down_proj_q_weight4, model_layers_15_mlp_down_proj_q_scale4, lv31_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv62 = R.call_tir(cls.fuse_add_norm_decode, (lv47_1, lv61, model_layers_16_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv63: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62[1]
            rms_norm251: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62[0]
            lv16_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_16_self_attn_c_attn_q_weight4, model_layers_16_self_attn_c_attn_q_scale4, rms_norm251, model_layers_16_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape496 = R.call_tir(cls.reshape, (lv16_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape497 = R.call_tir(cls.reshape1, (reshape496,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv626 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape497), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape498 = R.call_tir(cls.reshape2, (lv626,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape499 = R.call_tir(cls.reshape3, (reshape498,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv48_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_16_self_attn_o_proj_q_weight4, model_layers_16_self_attn_o_proj_q_scale4, reshape499), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv64 = R.call_tir(cls.fuse_add_norm_decode, (lv48_1, lv63, model_layers_16_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv65: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64[1]
            rms_norm252: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64[0]
            lv49_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_16_mlp_gate_up_proj_q_weight4, model_layers_16_mlp_gate_up_proj_q_scale4, rms_norm252), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv33_2 = R.call_tir(cls.fused_split_silu_multiply, (lv49_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv50_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_16_mlp_down_proj_q_weight4, model_layers_16_mlp_down_proj_q_scale4, lv33_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv66 = R.call_tir(cls.fuse_add_norm_decode, (lv50_1, lv65, model_layers_17_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv67: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66[1]
            rms_norm253: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66[0]
            lv17_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_17_self_attn_c_attn_q_weight4, model_layers_17_self_attn_c_attn_q_scale4, rms_norm253, model_layers_17_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape500 = R.call_tir(cls.reshape, (lv17_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape501 = R.call_tir(cls.reshape1, (reshape500,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv631 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape501), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape502 = R.call_tir(cls.reshape2, (lv631,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape503 = R.call_tir(cls.reshape3, (reshape502,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv51_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_17_self_attn_o_proj_q_weight4, model_layers_17_self_attn_o_proj_q_scale4, reshape503), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv68 = R.call_tir(cls.fuse_add_norm_decode, (lv51_1, lv67, model_layers_17_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv69: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68[1]
            rms_norm254: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68[0]
            lv52_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_17_mlp_gate_up_proj_q_weight4, model_layers_17_mlp_gate_up_proj_q_scale4, rms_norm254), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv35_2 = R.call_tir(cls.fused_split_silu_multiply, (lv52_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv53_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_17_mlp_down_proj_q_weight4, model_layers_17_mlp_down_proj_q_scale4, lv35_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv70 = R.call_tir(cls.fuse_add_norm_decode, (lv53_1, lv69, model_layers_18_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv71: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70[1]
            rms_norm255: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70[0]
            lv18_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_18_self_attn_c_attn_q_weight4, model_layers_18_self_attn_c_attn_q_scale4, rms_norm255, model_layers_18_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape504 = R.call_tir(cls.reshape, (lv18_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape505 = R.call_tir(cls.reshape1, (reshape504,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv636 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape505), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape506 = R.call_tir(cls.reshape2, (lv636,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape507 = R.call_tir(cls.reshape3, (reshape506,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv54_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_18_self_attn_o_proj_q_weight4, model_layers_18_self_attn_o_proj_q_scale4, reshape507), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv72 = R.call_tir(cls.fuse_add_norm_decode, (lv54_1, lv71, model_layers_18_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv73: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72[1]
            rms_norm256: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72[0]
            lv55_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_18_mlp_gate_up_proj_q_weight4, model_layers_18_mlp_gate_up_proj_q_scale4, rms_norm256), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv37_2 = R.call_tir(cls.fused_split_silu_multiply, (lv55_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv56_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_18_mlp_down_proj_q_weight4, model_layers_18_mlp_down_proj_q_scale4, lv37_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv74 = R.call_tir(cls.fuse_add_norm_decode, (lv56_1, lv73, model_layers_19_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv75: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74[1]
            rms_norm257: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74[0]
            lv19_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_19_self_attn_c_attn_q_weight4, model_layers_19_self_attn_c_attn_q_scale4, rms_norm257, model_layers_19_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape508 = R.call_tir(cls.reshape, (lv19_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape509 = R.call_tir(cls.reshape1, (reshape508,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv641 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape509), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape510 = R.call_tir(cls.reshape2, (lv641,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape511 = R.call_tir(cls.reshape3, (reshape510,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv57_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_19_self_attn_o_proj_q_weight4, model_layers_19_self_attn_o_proj_q_scale4, reshape511), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv76 = R.call_tir(cls.fuse_add_norm_decode, (lv57_1, lv75, model_layers_19_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv77: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76[1]
            rms_norm258: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76[0]
            lv58_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_19_mlp_gate_up_proj_q_weight4, model_layers_19_mlp_gate_up_proj_q_scale4, rms_norm258), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv39_2 = R.call_tir(cls.fused_split_silu_multiply, (lv58_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv59_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_19_mlp_down_proj_q_weight4, model_layers_19_mlp_down_proj_q_scale4, lv39_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv78 = R.call_tir(cls.fuse_add_norm_decode, (lv59_1, lv77, model_layers_20_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv79: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78[1]
            rms_norm259: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78[0]
            lv20_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_20_self_attn_c_attn_q_weight4, model_layers_20_self_attn_c_attn_q_scale4, rms_norm259, model_layers_20_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape512 = R.call_tir(cls.reshape, (lv20_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape513 = R.call_tir(cls.reshape1, (reshape512,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv646 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape513), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape514 = R.call_tir(cls.reshape2, (lv646,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape515 = R.call_tir(cls.reshape3, (reshape514,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv60_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_20_self_attn_o_proj_q_weight4, model_layers_20_self_attn_o_proj_q_scale4, reshape515), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv80 = R.call_tir(cls.fuse_add_norm_decode, (lv60_1, lv79, model_layers_20_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv81: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80[1]
            rms_norm260: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80[0]
            lv61_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_20_mlp_gate_up_proj_q_weight4, model_layers_20_mlp_gate_up_proj_q_scale4, rms_norm260), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv41_2 = R.call_tir(cls.fused_split_silu_multiply, (lv61_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv62_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_20_mlp_down_proj_q_weight4, model_layers_20_mlp_down_proj_q_scale4, lv41_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv82 = R.call_tir(cls.fuse_add_norm_decode, (lv62_1, lv81, model_layers_21_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv83: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82[1]
            rms_norm261: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82[0]
            lv21_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_21_self_attn_c_attn_q_weight4, model_layers_21_self_attn_c_attn_q_scale4, rms_norm261, model_layers_21_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape516 = R.call_tir(cls.reshape, (lv21_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape517 = R.call_tir(cls.reshape1, (reshape516,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv651 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape517), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape518 = R.call_tir(cls.reshape2, (lv651,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape519 = R.call_tir(cls.reshape3, (reshape518,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv63_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_21_self_attn_o_proj_q_weight4, model_layers_21_self_attn_o_proj_q_scale4, reshape519), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv84 = R.call_tir(cls.fuse_add_norm_decode, (lv63_1, lv83, model_layers_21_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv85: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84[1]
            rms_norm262: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84[0]
            lv64_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_21_mlp_gate_up_proj_q_weight4, model_layers_21_mlp_gate_up_proj_q_scale4, rms_norm262), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv43_2 = R.call_tir(cls.fused_split_silu_multiply, (lv64_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv65_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_21_mlp_down_proj_q_weight4, model_layers_21_mlp_down_proj_q_scale4, lv43_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv86 = R.call_tir(cls.fuse_add_norm_decode, (lv65_1, lv85, model_layers_22_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv87: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86[1]
            rms_norm263: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86[0]
            lv22_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_22_self_attn_c_attn_q_weight4, model_layers_22_self_attn_c_attn_q_scale4, rms_norm263, model_layers_22_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape520 = R.call_tir(cls.reshape, (lv22_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape521 = R.call_tir(cls.reshape1, (reshape520,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv656 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape521), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape522 = R.call_tir(cls.reshape2, (lv656,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape523 = R.call_tir(cls.reshape3, (reshape522,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv66_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_22_self_attn_o_proj_q_weight4, model_layers_22_self_attn_o_proj_q_scale4, reshape523), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv88 = R.call_tir(cls.fuse_add_norm_decode, (lv66_1, lv87, model_layers_22_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv89: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88[1]
            rms_norm264: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88[0]
            lv67_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_22_mlp_gate_up_proj_q_weight4, model_layers_22_mlp_gate_up_proj_q_scale4, rms_norm264), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv45_2 = R.call_tir(cls.fused_split_silu_multiply, (lv67_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv68_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_22_mlp_down_proj_q_weight4, model_layers_22_mlp_down_proj_q_scale4, lv45_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv90 = R.call_tir(cls.fuse_add_norm_decode, (lv68_1, lv89, model_layers_23_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv91: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90[1]
            rms_norm265: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90[0]
            lv23_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_23_self_attn_c_attn_q_weight4, model_layers_23_self_attn_c_attn_q_scale4, rms_norm265, model_layers_23_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape524 = R.call_tir(cls.reshape, (lv23_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape525 = R.call_tir(cls.reshape1, (reshape524,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv661 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape525), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape526 = R.call_tir(cls.reshape2, (lv661,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape527 = R.call_tir(cls.reshape3, (reshape526,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv69_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_23_self_attn_o_proj_q_weight4, model_layers_23_self_attn_o_proj_q_scale4, reshape527), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv92 = R.call_tir(cls.fuse_add_norm_decode, (lv69_1, lv91, model_layers_23_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv93: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92[1]
            rms_norm266: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92[0]
            lv70_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_23_mlp_gate_up_proj_q_weight4, model_layers_23_mlp_gate_up_proj_q_scale4, rms_norm266), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv47_2 = R.call_tir(cls.fused_split_silu_multiply, (lv70_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv71_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_23_mlp_down_proj_q_weight4, model_layers_23_mlp_down_proj_q_scale4, lv47_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv94 = R.call_tir(cls.fuse_add_norm_decode, (lv71_1, lv93, model_layers_24_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv95: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv94[1]
            rms_norm267: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv94[0]
            lv24_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_24_self_attn_c_attn_q_weight4, model_layers_24_self_attn_c_attn_q_scale4, rms_norm267, model_layers_24_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape528 = R.call_tir(cls.reshape, (lv24_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape529 = R.call_tir(cls.reshape1, (reshape528,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv666 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape529), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape530 = R.call_tir(cls.reshape2, (lv666,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape531 = R.call_tir(cls.reshape3, (reshape530,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv72_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_24_self_attn_o_proj_q_weight4, model_layers_24_self_attn_o_proj_q_scale4, reshape531), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv96 = R.call_tir(cls.fuse_add_norm_decode, (lv72_1, lv95, model_layers_24_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv97: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv96[1]
            rms_norm268: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv96[0]
            lv73_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_24_mlp_gate_up_proj_q_weight4, model_layers_24_mlp_gate_up_proj_q_scale4, rms_norm268), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv49_2 = R.call_tir(cls.fused_split_silu_multiply, (lv73_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv74_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_24_mlp_down_proj_q_weight4, model_layers_24_mlp_down_proj_q_scale4, lv49_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv98 = R.call_tir(cls.fuse_add_norm_decode, (lv74_1, lv97, model_layers_25_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv99: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv98[1]
            rms_norm269: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv98[0]
            lv25_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_25_self_attn_c_attn_q_weight4, model_layers_25_self_attn_c_attn_q_scale4, rms_norm269, model_layers_25_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape532 = R.call_tir(cls.reshape, (lv25_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape533 = R.call_tir(cls.reshape1, (reshape532,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv671 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape533), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape534 = R.call_tir(cls.reshape2, (lv671,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape535 = R.call_tir(cls.reshape3, (reshape534,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv75_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_25_self_attn_o_proj_q_weight4, model_layers_25_self_attn_o_proj_q_scale4, reshape535), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv100 = R.call_tir(cls.fuse_add_norm_decode, (lv75_1, lv99, model_layers_25_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv101: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv100[1]
            rms_norm270: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv100[0]
            lv76_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_25_mlp_gate_up_proj_q_weight4, model_layers_25_mlp_gate_up_proj_q_scale4, rms_norm270), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv51_2 = R.call_tir(cls.fused_split_silu_multiply, (lv76_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv77_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_25_mlp_down_proj_q_weight4, model_layers_25_mlp_down_proj_q_scale4, lv51_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv102 = R.call_tir(cls.fuse_add_norm_decode, (lv77_1, lv101, model_layers_26_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv103: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv102[1]
            rms_norm271: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv102[0]
            lv26_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_26_self_attn_c_attn_q_weight4, model_layers_26_self_attn_c_attn_q_scale4, rms_norm271, model_layers_26_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape536 = R.call_tir(cls.reshape, (lv26_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape537 = R.call_tir(cls.reshape1, (reshape536,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv676 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape537), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape538 = R.call_tir(cls.reshape2, (lv676,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape539 = R.call_tir(cls.reshape3, (reshape538,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv78_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_26_self_attn_o_proj_q_weight4, model_layers_26_self_attn_o_proj_q_scale4, reshape539), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv104 = R.call_tir(cls.fuse_add_norm_decode, (lv78_1, lv103, model_layers_26_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv105: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv104[1]
            rms_norm272: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv104[0]
            lv79_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_26_mlp_gate_up_proj_q_weight4, model_layers_26_mlp_gate_up_proj_q_scale4, rms_norm272), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv53_2 = R.call_tir(cls.fused_split_silu_multiply, (lv79_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv80_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_26_mlp_down_proj_q_weight4, model_layers_26_mlp_down_proj_q_scale4, lv53_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv106 = R.call_tir(cls.fuse_add_norm_decode, (lv80_1, lv105, model_layers_27_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv107: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv106[1]
            rms_norm273: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv106[0]
            lv27_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_27_self_attn_c_attn_q_weight4, model_layers_27_self_attn_c_attn_q_scale4, rms_norm273, model_layers_27_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape540 = R.call_tir(cls.reshape, (lv27_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape541 = R.call_tir(cls.reshape1, (reshape540,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv681 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape541), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape542 = R.call_tir(cls.reshape2, (lv681,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape543 = R.call_tir(cls.reshape3, (reshape542,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv81_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_27_self_attn_o_proj_q_weight4, model_layers_27_self_attn_o_proj_q_scale4, reshape543), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv108 = R.call_tir(cls.fuse_add_norm_decode, (lv81_1, lv107, model_layers_27_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv109: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv108[1]
            rms_norm274: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv108[0]
            lv82_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_27_mlp_gate_up_proj_q_weight4, model_layers_27_mlp_gate_up_proj_q_scale4, rms_norm274), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv55_2 = R.call_tir(cls.fused_split_silu_multiply, (lv82_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv83_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_27_mlp_down_proj_q_weight4, model_layers_27_mlp_down_proj_q_scale4, lv55_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv110 = R.call_tir(cls.fuse_add_norm_decode, (lv83_1, lv109, model_layers_28_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv111: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv110[1]
            rms_norm275: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv110[0]
            lv28_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_28_self_attn_c_attn_q_weight4, model_layers_28_self_attn_c_attn_q_scale4, rms_norm275, model_layers_28_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape544 = R.call_tir(cls.reshape, (lv28_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape545 = R.call_tir(cls.reshape1, (reshape544,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv686 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape545), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape546 = R.call_tir(cls.reshape2, (lv686,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape547 = R.call_tir(cls.reshape3, (reshape546,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv84_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_28_self_attn_o_proj_q_weight4, model_layers_28_self_attn_o_proj_q_scale4, reshape547), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv112 = R.call_tir(cls.fuse_add_norm_decode, (lv84_1, lv111, model_layers_28_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv113: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv112[1]
            rms_norm276: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv112[0]
            lv85_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_28_mlp_gate_up_proj_q_weight4, model_layers_28_mlp_gate_up_proj_q_scale4, rms_norm276), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv57_2 = R.call_tir(cls.fused_split_silu_multiply, (lv85_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv86_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_28_mlp_down_proj_q_weight4, model_layers_28_mlp_down_proj_q_scale4, lv57_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv114 = R.call_tir(cls.fuse_add_norm_decode, (lv86_1, lv113, model_layers_29_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv115: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv114[1]
            rms_norm277: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv114[0]
            lv29_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_29_self_attn_c_attn_q_weight4, model_layers_29_self_attn_c_attn_q_scale4, rms_norm277, model_layers_29_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape548 = R.call_tir(cls.reshape, (lv29_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape549 = R.call_tir(cls.reshape1, (reshape548,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv691 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape549), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape550 = R.call_tir(cls.reshape2, (lv691,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape551 = R.call_tir(cls.reshape3, (reshape550,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv87_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_29_self_attn_o_proj_q_weight4, model_layers_29_self_attn_o_proj_q_scale4, reshape551), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv116 = R.call_tir(cls.fuse_add_norm_decode, (lv87_1, lv115, model_layers_29_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv117: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv116[1]
            rms_norm278: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv116[0]
            lv88_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_29_mlp_gate_up_proj_q_weight4, model_layers_29_mlp_gate_up_proj_q_scale4, rms_norm278), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv59_2 = R.call_tir(cls.fused_split_silu_multiply, (lv88_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv89_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_29_mlp_down_proj_q_weight4, model_layers_29_mlp_down_proj_q_scale4, lv59_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv118 = R.call_tir(cls.fuse_add_norm_decode, (lv89_1, lv117, model_layers_30_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv119: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv118[1]
            rms_norm279: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv118[0]
            lv30_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_30_self_attn_c_attn_q_weight4, model_layers_30_self_attn_c_attn_q_scale4, rms_norm279, model_layers_30_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape552 = R.call_tir(cls.reshape, (lv30_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape553 = R.call_tir(cls.reshape1, (reshape552,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv696 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape553), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape554 = R.call_tir(cls.reshape2, (lv696,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape555 = R.call_tir(cls.reshape3, (reshape554,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv90_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_30_self_attn_o_proj_q_weight4, model_layers_30_self_attn_o_proj_q_scale4, reshape555), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv120 = R.call_tir(cls.fuse_add_norm_decode, (lv90_1, lv119, model_layers_30_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv121: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv120[1]
            rms_norm280: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv120[0]
            lv91_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_30_mlp_gate_up_proj_q_weight4, model_layers_30_mlp_gate_up_proj_q_scale4, rms_norm280), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv61_2 = R.call_tir(cls.fused_split_silu_multiply, (lv91_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv92_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_30_mlp_down_proj_q_weight4, model_layers_30_mlp_down_proj_q_scale4, lv61_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv122 = R.call_tir(cls.fuse_add_norm_decode, (lv92_1, lv121, model_layers_31_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv123: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv122[1]
            rms_norm281: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv122[0]
            lv31_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_31_self_attn_c_attn_q_weight4, model_layers_31_self_attn_c_attn_q_scale4, rms_norm281, model_layers_31_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape556 = R.call_tir(cls.reshape, (lv31_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape557 = R.call_tir(cls.reshape1, (reshape556,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv701 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape557), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape558 = R.call_tir(cls.reshape2, (lv701,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape559 = R.call_tir(cls.reshape3, (reshape558,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv93_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_31_self_attn_o_proj_q_weight4, model_layers_31_self_attn_o_proj_q_scale4, reshape559), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv124 = R.call_tir(cls.fuse_add_norm_decode, (lv93_1, lv123, model_layers_31_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv125: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv124[1]
            rms_norm282: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv124[0]
            lv94_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_31_mlp_gate_up_proj_q_weight4, model_layers_31_mlp_gate_up_proj_q_scale4, rms_norm282), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv63_2 = R.call_tir(cls.fused_split_silu_multiply, (lv94_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv95_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_31_mlp_down_proj_q_weight4, model_layers_31_mlp_down_proj_q_scale4, lv63_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv126 = R.call_tir(cls.fuse_add_norm_decode, (lv95_1, lv125, model_layers_32_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv127: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv126[1]
            rms_norm283: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv126[0]
            lv32_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_32_self_attn_c_attn_q_weight4, model_layers_32_self_attn_c_attn_q_scale4, rms_norm283, model_layers_32_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape560 = R.call_tir(cls.reshape, (lv32_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape561 = R.call_tir(cls.reshape1, (reshape560,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv706 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape561), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape562 = R.call_tir(cls.reshape2, (lv706,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape563 = R.call_tir(cls.reshape3, (reshape562,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv96_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_32_self_attn_o_proj_q_weight4, model_layers_32_self_attn_o_proj_q_scale4, reshape563), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv128 = R.call_tir(cls.fuse_add_norm_decode, (lv96_1, lv127, model_layers_32_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv129: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv128[1]
            rms_norm284: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv128[0]
            lv97_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_32_mlp_gate_up_proj_q_weight4, model_layers_32_mlp_gate_up_proj_q_scale4, rms_norm284), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv65_2 = R.call_tir(cls.fused_split_silu_multiply, (lv97_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv98_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_32_mlp_down_proj_q_weight4, model_layers_32_mlp_down_proj_q_scale4, lv65_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv130 = R.call_tir(cls.fuse_add_norm_decode, (lv98_1, lv129, model_layers_33_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv131: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv130[1]
            rms_norm285: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv130[0]
            lv33_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_33_self_attn_c_attn_q_weight4, model_layers_33_self_attn_c_attn_q_scale4, rms_norm285, model_layers_33_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape564 = R.call_tir(cls.reshape, (lv33_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape565 = R.call_tir(cls.reshape1, (reshape564,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv711 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape565), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape566 = R.call_tir(cls.reshape2, (lv711,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape567 = R.call_tir(cls.reshape3, (reshape566,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv99_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_33_self_attn_o_proj_q_weight4, model_layers_33_self_attn_o_proj_q_scale4, reshape567), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv132 = R.call_tir(cls.fuse_add_norm_decode, (lv99_1, lv131, model_layers_33_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv133: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv132[1]
            rms_norm286: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv132[0]
            lv100_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_33_mlp_gate_up_proj_q_weight4, model_layers_33_mlp_gate_up_proj_q_scale4, rms_norm286), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv67_2 = R.call_tir(cls.fused_split_silu_multiply, (lv100_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv101_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_33_mlp_down_proj_q_weight4, model_layers_33_mlp_down_proj_q_scale4, lv67_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv134 = R.call_tir(cls.fuse_add_norm_decode, (lv101_1, lv133, model_layers_34_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv135: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv134[1]
            rms_norm287: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv134[0]
            lv34_2 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_34_self_attn_c_attn_q_weight4, model_layers_34_self_attn_c_attn_q_scale4, rms_norm287, model_layers_34_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape568 = R.call_tir(cls.reshape, (lv34_2,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape569 = R.call_tir(cls.reshape1, (reshape568,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv716 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape569), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape570 = R.call_tir(cls.reshape2, (lv716,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape571 = R.call_tir(cls.reshape3, (reshape570,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv102_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_34_self_attn_o_proj_q_weight4, model_layers_34_self_attn_o_proj_q_scale4, reshape571), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv136 = R.call_tir(cls.fuse_add_norm_decode, (lv102_1, lv135, model_layers_34_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv137: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv136[1]
            rms_norm288: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv136[0]
            lv103_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_34_mlp_gate_up_proj_q_weight4, model_layers_34_mlp_gate_up_proj_q_scale4, rms_norm288), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv69_2 = R.call_tir(cls.fused_split_silu_multiply, (lv103_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv104_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_34_mlp_down_proj_q_weight4, model_layers_34_mlp_down_proj_q_scale4, lv69_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv138 = R.call_tir(cls.fuse_add_norm_decode, (lv104_1, lv137, model_layers_35_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv139: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv138[1]
            rms_norm289: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv138[0]
            lv35_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_35_self_attn_c_attn_q_weight4, model_layers_35_self_attn_c_attn_q_scale4, rms_norm289, model_layers_35_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 2560), dtype="float16"))
            reshape572 = R.call_tir(cls.reshape, (lv35_3,), out_sinfo=R.Tensor((batch_size, 1, 20, 128), dtype="float16"))
            reshape573 = R.call_tir(cls.reshape1, (reshape572,), out_sinfo=R.Tensor((batch_size, 20, 128), dtype="float16"))
            lv721 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape573), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape574 = R.call_tir(cls.reshape2, (lv721,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape575 = R.call_tir(cls.reshape3, (reshape574,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv105_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_35_self_attn_o_proj_q_weight4, model_layers_35_self_attn_o_proj_q_scale4, reshape575), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv140 = R.call_tir(cls.fuse_add_norm_decode, (lv105_1, lv139, model_layers_35_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv141: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv140[1]
            rms_norm290: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv140[0]
            lv106_1 = R.call_tir(cls.fused_dequantize3_NT_matmul2, (model_layers_35_mlp_gate_up_proj_q_weight4, model_layers_35_mlp_gate_up_proj_q_scale4, rms_norm290), out_sinfo=R.Tensor((batch_size, 1, 22016), dtype="float16"))
            lv71_2 = R.call_tir(cls.fused_split_silu_multiply, (lv106_1,), out_sinfo=R.Tensor((batch_size, 1, 11008), dtype="float16"))
            lv107_1 = R.call_tir(cls.fused_dequantize4_NT_matmul3, (model_layers_35_mlp_down_proj_q_weight4, model_layers_35_mlp_down_proj_q_scale4, lv71_2), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv142 = R.call_tir(cls.fuse_add_norm_decode, (lv107_1, lv141, model_norm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            rms_norm291: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv142[0]
            lv108_1 = R.call_tir(cls.fused_dequantize_NT_matmul4, (model_embed_tokens_q_weight4, model_embed_tokens_q_scale4, rms_norm291), out_sinfo=R.Tensor((batch_size, 1, 151936), dtype="float32"))
            gv4: R.Tuple(R.Tensor((batch_size, 1, 151936), dtype="float32"), R.Object) = lv108_1, paged_kv_cache
            R.output(gv4)
        return gv4

    @R.function
    def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), 
dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), 
R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), 
R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), 
dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), 
R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        seq_len = T.int64()
        R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight3: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale3: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight3: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale3: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias3: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight3: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale3: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight3: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale3: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm146 = R.call_tir(cls.rms_norm1, (input_embeds, model_layers_0_input_layernorm_weight3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv36 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_0_self_attn_c_attn_q_weight3, model_layers_0_self_attn_c_attn_q_scale3, rms_norm146, model_layers_0_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape288 = R.call_tir(cls.reshape4, (lv36,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape289 = R.call_tir(cls.reshape5, (reshape288,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv365 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape289), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape290 = R.call_tir(cls.reshape6, (lv365,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape291 = R.call_tir(cls.reshape7, (reshape290,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv109 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_0_self_attn_o_proj_q_weight3, model_layers_0_self_attn_o_proj_q_scale3, reshape291), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv144 = R.call_tir(cls.fuse_add_norm_prefill, (lv109, input_embeds, model_layers_0_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv145: R.Tensor((1, seq_len, 2048), dtype="float16") = lv144[1]
            rms_norm147: R.Tensor((1, seq_len, 2048), dtype="float16") = lv144[0]
            lv110 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_0_mlp_gate_up_proj_q_weight3, model_layers_0_mlp_gate_up_proj_q_scale3, rms_norm147), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv73 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv110,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv111 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_0_mlp_down_proj_q_weight3, model_layers_0_mlp_down_proj_q_scale3, lv73), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv146 = R.call_tir(cls.fuse_add_norm_prefill, (lv111, lv145, model_layers_1_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv147: R.Tensor((1, seq_len, 2048), dtype="float16") = lv146[1]
            rms_norm148: R.Tensor((1, seq_len, 2048), dtype="float16") = lv146[0]
            lv37 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_1_self_attn_c_attn_q_weight3, model_layers_1_self_attn_c_attn_q_scale3, rms_norm148, model_layers_1_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape292 = R.call_tir(cls.reshape4, (lv37,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape293 = R.call_tir(cls.reshape5, (reshape292,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv370 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape293), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape294 = R.call_tir(cls.reshape6, (lv370,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape295 = R.call_tir(cls.reshape7, (reshape294,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv112 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_1_self_attn_o_proj_q_weight3, model_layers_1_self_attn_o_proj_q_scale3, reshape295), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv148 = R.call_tir(cls.fuse_add_norm_prefill, (lv112, lv147, model_layers_1_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv149: R.Tensor((1, seq_len, 2048), dtype="float16") = lv148[1]
            rms_norm149: R.Tensor((1, seq_len, 2048), dtype="float16") = lv148[0]
            lv113 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_1_mlp_gate_up_proj_q_weight3, model_layers_1_mlp_gate_up_proj_q_scale3, rms_norm149), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv75 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv113,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv114 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_1_mlp_down_proj_q_weight3, model_layers_1_mlp_down_proj_q_scale3, lv75), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv150 = R.call_tir(cls.fuse_add_norm_prefill, (lv114, lv149, model_layers_2_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv151: R.Tensor((1, seq_len, 2048), dtype="float16") = lv150[1]
            rms_norm150: R.Tensor((1, seq_len, 2048), dtype="float16") = lv150[0]
            lv38 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_2_self_attn_c_attn_q_weight3, model_layers_2_self_attn_c_attn_q_scale3, rms_norm150, model_layers_2_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape296 = R.call_tir(cls.reshape4, (lv38,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape297 = R.call_tir(cls.reshape5, (reshape296,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv375 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape297), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape298 = R.call_tir(cls.reshape6, (lv375,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape299 = R.call_tir(cls.reshape7, (reshape298,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv115 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_2_self_attn_o_proj_q_weight3, model_layers_2_self_attn_o_proj_q_scale3, reshape299), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv152 = R.call_tir(cls.fuse_add_norm_prefill, (lv115, lv151, model_layers_2_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv153: R.Tensor((1, seq_len, 2048), dtype="float16") = lv152[1]
            rms_norm151: R.Tensor((1, seq_len, 2048), dtype="float16") = lv152[0]
            lv116 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_2_mlp_gate_up_proj_q_weight3, model_layers_2_mlp_gate_up_proj_q_scale3, rms_norm151), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv77 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv116,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv117 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_2_mlp_down_proj_q_weight3, model_layers_2_mlp_down_proj_q_scale3, lv77), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv154 = R.call_tir(cls.fuse_add_norm_prefill, (lv117, lv153, model_layers_3_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv155: R.Tensor((1, seq_len, 2048), dtype="float16") = lv154[1]
            rms_norm152: R.Tensor((1, seq_len, 2048), dtype="float16") = lv154[0]
            lv39 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_3_self_attn_c_attn_q_weight3, model_layers_3_self_attn_c_attn_q_scale3, rms_norm152, model_layers_3_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape300 = R.call_tir(cls.reshape4, (lv39,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape301 = R.call_tir(cls.reshape5, (reshape300,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv380 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape301), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape302 = R.call_tir(cls.reshape6, (lv380,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape303 = R.call_tir(cls.reshape7, (reshape302,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv118 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_3_self_attn_o_proj_q_weight3, model_layers_3_self_attn_o_proj_q_scale3, reshape303), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv156 = R.call_tir(cls.fuse_add_norm_prefill, (lv118, lv155, model_layers_3_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv157: R.Tensor((1, seq_len, 2048), dtype="float16") = lv156[1]
            rms_norm153: R.Tensor((1, seq_len, 2048), dtype="float16") = lv156[0]
            lv119 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_3_mlp_gate_up_proj_q_weight3, model_layers_3_mlp_gate_up_proj_q_scale3, rms_norm153), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv79 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv119,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv120 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_3_mlp_down_proj_q_weight3, model_layers_3_mlp_down_proj_q_scale3, lv79), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv158 = R.call_tir(cls.fuse_add_norm_prefill, (lv120, lv157, model_layers_4_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv159: R.Tensor((1, seq_len, 2048), dtype="float16") = lv158[1]
            rms_norm154: R.Tensor((1, seq_len, 2048), dtype="float16") = lv158[0]
            lv40 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_4_self_attn_c_attn_q_weight3, model_layers_4_self_attn_c_attn_q_scale3, rms_norm154, model_layers_4_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape304 = R.call_tir(cls.reshape4, (lv40,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape305 = R.call_tir(cls.reshape5, (reshape304,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv385 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape305), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape306 = R.call_tir(cls.reshape6, (lv385,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape307 = R.call_tir(cls.reshape7, (reshape306,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv121 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_4_self_attn_o_proj_q_weight3, model_layers_4_self_attn_o_proj_q_scale3, reshape307), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv160 = R.call_tir(cls.fuse_add_norm_prefill, (lv121, lv159, model_layers_4_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv161: R.Tensor((1, seq_len, 2048), dtype="float16") = lv160[1]
            rms_norm155: R.Tensor((1, seq_len, 2048), dtype="float16") = lv160[0]
            lv122 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_4_mlp_gate_up_proj_q_weight3, model_layers_4_mlp_gate_up_proj_q_scale3, rms_norm155), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv81 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv122,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv123 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_4_mlp_down_proj_q_weight3, model_layers_4_mlp_down_proj_q_scale3, lv81), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv162 = R.call_tir(cls.fuse_add_norm_prefill, (lv123, lv161, model_layers_5_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv163: R.Tensor((1, seq_len, 2048), dtype="float16") = lv162[1]
            rms_norm156: R.Tensor((1, seq_len, 2048), dtype="float16") = lv162[0]
            lv41 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_5_self_attn_c_attn_q_weight3, model_layers_5_self_attn_c_attn_q_scale3, rms_norm156, model_layers_5_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape308 = R.call_tir(cls.reshape4, (lv41,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape309 = R.call_tir(cls.reshape5, (reshape308,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv390 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape309), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape310 = R.call_tir(cls.reshape6, (lv390,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape311 = R.call_tir(cls.reshape7, (reshape310,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv124 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_5_self_attn_o_proj_q_weight3, model_layers_5_self_attn_o_proj_q_scale3, reshape311), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv164 = R.call_tir(cls.fuse_add_norm_prefill, (lv124, lv163, model_layers_5_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv165: R.Tensor((1, seq_len, 2048), dtype="float16") = lv164[1]
            rms_norm157: R.Tensor((1, seq_len, 2048), dtype="float16") = lv164[0]
            lv125 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_5_mlp_gate_up_proj_q_weight3, model_layers_5_mlp_gate_up_proj_q_scale3, rms_norm157), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv83 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv125,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv126 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_5_mlp_down_proj_q_weight3, model_layers_5_mlp_down_proj_q_scale3, lv83), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv166 = R.call_tir(cls.fuse_add_norm_prefill, (lv126, lv165, model_layers_6_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv167: R.Tensor((1, seq_len, 2048), dtype="float16") = lv166[1]
            rms_norm158: R.Tensor((1, seq_len, 2048), dtype="float16") = lv166[0]
            lv42 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_6_self_attn_c_attn_q_weight3, model_layers_6_self_attn_c_attn_q_scale3, rms_norm158, model_layers_6_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape312 = R.call_tir(cls.reshape4, (lv42,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape313 = R.call_tir(cls.reshape5, (reshape312,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv395 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape313), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape314 = R.call_tir(cls.reshape6, (lv395,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape315 = R.call_tir(cls.reshape7, (reshape314,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv127 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_6_self_attn_o_proj_q_weight3, model_layers_6_self_attn_o_proj_q_scale3, reshape315), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv168 = R.call_tir(cls.fuse_add_norm_prefill, (lv127, lv167, model_layers_6_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv169: R.Tensor((1, seq_len, 2048), dtype="float16") = lv168[1]
            rms_norm159: R.Tensor((1, seq_len, 2048), dtype="float16") = lv168[0]
            lv128 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_6_mlp_gate_up_proj_q_weight3, model_layers_6_mlp_gate_up_proj_q_scale3, rms_norm159), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv85 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv128,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv129 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_6_mlp_down_proj_q_weight3, model_layers_6_mlp_down_proj_q_scale3, lv85), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv170 = R.call_tir(cls.fuse_add_norm_prefill, (lv129, lv169, model_layers_7_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv171: R.Tensor((1, seq_len, 2048), dtype="float16") = lv170[1]
            rms_norm160: R.Tensor((1, seq_len, 2048), dtype="float16") = lv170[0]
            lv43 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_7_self_attn_c_attn_q_weight3, model_layers_7_self_attn_c_attn_q_scale3, rms_norm160, model_layers_7_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape316 = R.call_tir(cls.reshape4, (lv43,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape317 = R.call_tir(cls.reshape5, (reshape316,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv400 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape317), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape318 = R.call_tir(cls.reshape6, (lv400,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape319 = R.call_tir(cls.reshape7, (reshape318,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv130 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_7_self_attn_o_proj_q_weight3, model_layers_7_self_attn_o_proj_q_scale3, reshape319), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv172 = R.call_tir(cls.fuse_add_norm_prefill, (lv130, lv171, model_layers_7_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv173: R.Tensor((1, seq_len, 2048), dtype="float16") = lv172[1]
            rms_norm161: R.Tensor((1, seq_len, 2048), dtype="float16") = lv172[0]
            lv131 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_7_mlp_gate_up_proj_q_weight3, model_layers_7_mlp_gate_up_proj_q_scale3, rms_norm161), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv87 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv131,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv132 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_7_mlp_down_proj_q_weight3, model_layers_7_mlp_down_proj_q_scale3, lv87), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv174 = R.call_tir(cls.fuse_add_norm_prefill, (lv132, lv173, model_layers_8_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv175: R.Tensor((1, seq_len, 2048), dtype="float16") = lv174[1]
            rms_norm162: R.Tensor((1, seq_len, 2048), dtype="float16") = lv174[0]
            lv44 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_8_self_attn_c_attn_q_weight3, model_layers_8_self_attn_c_attn_q_scale3, rms_norm162, model_layers_8_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape320 = R.call_tir(cls.reshape4, (lv44,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape321 = R.call_tir(cls.reshape5, (reshape320,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv405 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape321), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape322 = R.call_tir(cls.reshape6, (lv405,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape323 = R.call_tir(cls.reshape7, (reshape322,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv133 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_8_self_attn_o_proj_q_weight3, model_layers_8_self_attn_o_proj_q_scale3, reshape323), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv176 = R.call_tir(cls.fuse_add_norm_prefill, (lv133, lv175, model_layers_8_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv177: R.Tensor((1, seq_len, 2048), dtype="float16") = lv176[1]
            rms_norm163: R.Tensor((1, seq_len, 2048), dtype="float16") = lv176[0]
            lv134 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_8_mlp_gate_up_proj_q_weight3, model_layers_8_mlp_gate_up_proj_q_scale3, rms_norm163), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv89 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv134,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv135 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_8_mlp_down_proj_q_weight3, model_layers_8_mlp_down_proj_q_scale3, lv89), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv178 = R.call_tir(cls.fuse_add_norm_prefill, (lv135, lv177, model_layers_9_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv179: R.Tensor((1, seq_len, 2048), dtype="float16") = lv178[1]
            rms_norm164: R.Tensor((1, seq_len, 2048), dtype="float16") = lv178[0]
            lv45 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_9_self_attn_c_attn_q_weight3, model_layers_9_self_attn_c_attn_q_scale3, rms_norm164, model_layers_9_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape324 = R.call_tir(cls.reshape4, (lv45,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape325 = R.call_tir(cls.reshape5, (reshape324,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv410 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape325), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape326 = R.call_tir(cls.reshape6, (lv410,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape327 = R.call_tir(cls.reshape7, (reshape326,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv136 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_9_self_attn_o_proj_q_weight3, model_layers_9_self_attn_o_proj_q_scale3, reshape327), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv180 = R.call_tir(cls.fuse_add_norm_prefill, (lv136, lv179, model_layers_9_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv181: R.Tensor((1, seq_len, 2048), dtype="float16") = lv180[1]
            rms_norm165: R.Tensor((1, seq_len, 2048), dtype="float16") = lv180[0]
            lv137 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_9_mlp_gate_up_proj_q_weight3, model_layers_9_mlp_gate_up_proj_q_scale3, rms_norm165), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv91 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv137,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv138 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_9_mlp_down_proj_q_weight3, model_layers_9_mlp_down_proj_q_scale3, lv91), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv182 = R.call_tir(cls.fuse_add_norm_prefill, (lv138, lv181, model_layers_10_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv183: R.Tensor((1, seq_len, 2048), dtype="float16") = lv182[1]
            rms_norm166: R.Tensor((1, seq_len, 2048), dtype="float16") = lv182[0]
            lv46 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_10_self_attn_c_attn_q_weight3, model_layers_10_self_attn_c_attn_q_scale3, rms_norm166, model_layers_10_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape328 = R.call_tir(cls.reshape4, (lv46,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape329 = R.call_tir(cls.reshape5, (reshape328,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv415 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape329), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape330 = R.call_tir(cls.reshape6, (lv415,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape331 = R.call_tir(cls.reshape7, (reshape330,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv139 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_10_self_attn_o_proj_q_weight3, model_layers_10_self_attn_o_proj_q_scale3, reshape331), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv184 = R.call_tir(cls.fuse_add_norm_prefill, (lv139, lv183, model_layers_10_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv185: R.Tensor((1, seq_len, 2048), dtype="float16") = lv184[1]
            rms_norm167: R.Tensor((1, seq_len, 2048), dtype="float16") = lv184[0]
            lv140 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_10_mlp_gate_up_proj_q_weight3, model_layers_10_mlp_gate_up_proj_q_scale3, rms_norm167), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv93 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv140,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv141 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_10_mlp_down_proj_q_weight3, model_layers_10_mlp_down_proj_q_scale3, lv93), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv186 = R.call_tir(cls.fuse_add_norm_prefill, (lv141, lv185, model_layers_11_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv187: R.Tensor((1, seq_len, 2048), dtype="float16") = lv186[1]
            rms_norm168: R.Tensor((1, seq_len, 2048), dtype="float16") = lv186[0]
            lv47 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_11_self_attn_c_attn_q_weight3, model_layers_11_self_attn_c_attn_q_scale3, rms_norm168, model_layers_11_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape332 = R.call_tir(cls.reshape4, (lv47,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape333 = R.call_tir(cls.reshape5, (reshape332,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv420 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape333), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape334 = R.call_tir(cls.reshape6, (lv420,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape335 = R.call_tir(cls.reshape7, (reshape334,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv142 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_11_self_attn_o_proj_q_weight3, model_layers_11_self_attn_o_proj_q_scale3, reshape335), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv188 = R.call_tir(cls.fuse_add_norm_prefill, (lv142, lv187, model_layers_11_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv189: R.Tensor((1, seq_len, 2048), dtype="float16") = lv188[1]
            rms_norm169: R.Tensor((1, seq_len, 2048), dtype="float16") = lv188[0]
            lv143 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_11_mlp_gate_up_proj_q_weight3, model_layers_11_mlp_gate_up_proj_q_scale3, rms_norm169), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv95 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv143,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv144_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_11_mlp_down_proj_q_weight3, model_layers_11_mlp_down_proj_q_scale3, lv95), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv190 = R.call_tir(cls.fuse_add_norm_prefill, (lv144_1, lv189, model_layers_12_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv191: R.Tensor((1, seq_len, 2048), dtype="float16") = lv190[1]
            rms_norm170: R.Tensor((1, seq_len, 2048), dtype="float16") = lv190[0]
            lv48 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_12_self_attn_c_attn_q_weight3, model_layers_12_self_attn_c_attn_q_scale3, rms_norm170, model_layers_12_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape336 = R.call_tir(cls.reshape4, (lv48,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape337 = R.call_tir(cls.reshape5, (reshape336,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv425 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape337), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape338 = R.call_tir(cls.reshape6, (lv425,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape339 = R.call_tir(cls.reshape7, (reshape338,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv145_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_12_self_attn_o_proj_q_weight3, model_layers_12_self_attn_o_proj_q_scale3, reshape339), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv192 = R.call_tir(cls.fuse_add_norm_prefill, (lv145_1, lv191, model_layers_12_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv193: R.Tensor((1, seq_len, 2048), dtype="float16") = lv192[1]
            rms_norm171: R.Tensor((1, seq_len, 2048), dtype="float16") = lv192[0]
            lv146_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_12_mlp_gate_up_proj_q_weight3, model_layers_12_mlp_gate_up_proj_q_scale3, rms_norm171), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv97 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv146_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv147_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_12_mlp_down_proj_q_weight3, model_layers_12_mlp_down_proj_q_scale3, lv97), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv194 = R.call_tir(cls.fuse_add_norm_prefill, (lv147_1, lv193, model_layers_13_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv195: R.Tensor((1, seq_len, 2048), dtype="float16") = lv194[1]
            rms_norm172: R.Tensor((1, seq_len, 2048), dtype="float16") = lv194[0]
            lv49 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_13_self_attn_c_attn_q_weight3, model_layers_13_self_attn_c_attn_q_scale3, rms_norm172, model_layers_13_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape340 = R.call_tir(cls.reshape4, (lv49,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape341 = R.call_tir(cls.reshape5, (reshape340,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv430 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape341), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape342 = R.call_tir(cls.reshape6, (lv430,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape343 = R.call_tir(cls.reshape7, (reshape342,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv148_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_13_self_attn_o_proj_q_weight3, model_layers_13_self_attn_o_proj_q_scale3, reshape343), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv196 = R.call_tir(cls.fuse_add_norm_prefill, (lv148_1, lv195, model_layers_13_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv197: R.Tensor((1, seq_len, 2048), dtype="float16") = lv196[1]
            rms_norm173: R.Tensor((1, seq_len, 2048), dtype="float16") = lv196[0]
            lv149_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_13_mlp_gate_up_proj_q_weight3, model_layers_13_mlp_gate_up_proj_q_scale3, rms_norm173), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv99 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv149_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv150_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_13_mlp_down_proj_q_weight3, model_layers_13_mlp_down_proj_q_scale3, lv99), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv198 = R.call_tir(cls.fuse_add_norm_prefill, (lv150_1, lv197, model_layers_14_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv199: R.Tensor((1, seq_len, 2048), dtype="float16") = lv198[1]
            rms_norm174: R.Tensor((1, seq_len, 2048), dtype="float16") = lv198[0]
            lv50 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_14_self_attn_c_attn_q_weight3, model_layers_14_self_attn_c_attn_q_scale3, rms_norm174, model_layers_14_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape344 = R.call_tir(cls.reshape4, (lv50,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape345 = R.call_tir(cls.reshape5, (reshape344,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv435 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape345), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape346 = R.call_tir(cls.reshape6, (lv435,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape347 = R.call_tir(cls.reshape7, (reshape346,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv151_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_14_self_attn_o_proj_q_weight3, model_layers_14_self_attn_o_proj_q_scale3, reshape347), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv200 = R.call_tir(cls.fuse_add_norm_prefill, (lv151_1, lv199, model_layers_14_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv201: R.Tensor((1, seq_len, 2048), dtype="float16") = lv200[1]
            rms_norm175: R.Tensor((1, seq_len, 2048), dtype="float16") = lv200[0]
            lv152_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_14_mlp_gate_up_proj_q_weight3, model_layers_14_mlp_gate_up_proj_q_scale3, rms_norm175), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv101 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv152_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv153_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_14_mlp_down_proj_q_weight3, model_layers_14_mlp_down_proj_q_scale3, lv101), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv202 = R.call_tir(cls.fuse_add_norm_prefill, (lv153_1, lv201, model_layers_15_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv203: R.Tensor((1, seq_len, 2048), dtype="float16") = lv202[1]
            rms_norm176: R.Tensor((1, seq_len, 2048), dtype="float16") = lv202[0]
            lv51 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_15_self_attn_c_attn_q_weight3, model_layers_15_self_attn_c_attn_q_scale3, rms_norm176, model_layers_15_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape348 = R.call_tir(cls.reshape4, (lv51,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape349 = R.call_tir(cls.reshape5, (reshape348,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv440 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape349), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape350 = R.call_tir(cls.reshape6, (lv440,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape351 = R.call_tir(cls.reshape7, (reshape350,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv154_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_15_self_attn_o_proj_q_weight3, model_layers_15_self_attn_o_proj_q_scale3, reshape351), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv204 = R.call_tir(cls.fuse_add_norm_prefill, (lv154_1, lv203, model_layers_15_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv205: R.Tensor((1, seq_len, 2048), dtype="float16") = lv204[1]
            rms_norm177: R.Tensor((1, seq_len, 2048), dtype="float16") = lv204[0]
            lv155_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_15_mlp_gate_up_proj_q_weight3, model_layers_15_mlp_gate_up_proj_q_scale3, rms_norm177), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv103 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv155_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv156_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_15_mlp_down_proj_q_weight3, model_layers_15_mlp_down_proj_q_scale3, lv103), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv206 = R.call_tir(cls.fuse_add_norm_prefill, (lv156_1, lv205, model_layers_16_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv207: R.Tensor((1, seq_len, 2048), dtype="float16") = lv206[1]
            rms_norm178: R.Tensor((1, seq_len, 2048), dtype="float16") = lv206[0]
            lv52 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_16_self_attn_c_attn_q_weight3, model_layers_16_self_attn_c_attn_q_scale3, rms_norm178, model_layers_16_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape352 = R.call_tir(cls.reshape4, (lv52,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape353 = R.call_tir(cls.reshape5, (reshape352,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv445 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape353), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape354 = R.call_tir(cls.reshape6, (lv445,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape355 = R.call_tir(cls.reshape7, (reshape354,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv157_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_16_self_attn_o_proj_q_weight3, model_layers_16_self_attn_o_proj_q_scale3, reshape355), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv208 = R.call_tir(cls.fuse_add_norm_prefill, (lv157_1, lv207, model_layers_16_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv209: R.Tensor((1, seq_len, 2048), dtype="float16") = lv208[1]
            rms_norm179: R.Tensor((1, seq_len, 2048), dtype="float16") = lv208[0]
            lv158_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_16_mlp_gate_up_proj_q_weight3, model_layers_16_mlp_gate_up_proj_q_scale3, rms_norm179), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv105 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv158_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv159_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_16_mlp_down_proj_q_weight3, model_layers_16_mlp_down_proj_q_scale3, lv105), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv210 = R.call_tir(cls.fuse_add_norm_prefill, (lv159_1, lv209, model_layers_17_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv211: R.Tensor((1, seq_len, 2048), dtype="float16") = lv210[1]
            rms_norm180: R.Tensor((1, seq_len, 2048), dtype="float16") = lv210[0]
            lv53 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_17_self_attn_c_attn_q_weight3, model_layers_17_self_attn_c_attn_q_scale3, rms_norm180, model_layers_17_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape356 = R.call_tir(cls.reshape4, (lv53,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape357 = R.call_tir(cls.reshape5, (reshape356,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv450 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape357), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape358 = R.call_tir(cls.reshape6, (lv450,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape359 = R.call_tir(cls.reshape7, (reshape358,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv160_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_17_self_attn_o_proj_q_weight3, model_layers_17_self_attn_o_proj_q_scale3, reshape359), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv212 = R.call_tir(cls.fuse_add_norm_prefill, (lv160_1, lv211, model_layers_17_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv213: R.Tensor((1, seq_len, 2048), dtype="float16") = lv212[1]
            rms_norm181: R.Tensor((1, seq_len, 2048), dtype="float16") = lv212[0]
            lv161_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_17_mlp_gate_up_proj_q_weight3, model_layers_17_mlp_gate_up_proj_q_scale3, rms_norm181), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv107 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv161_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv162_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_17_mlp_down_proj_q_weight3, model_layers_17_mlp_down_proj_q_scale3, lv107), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv214 = R.call_tir(cls.fuse_add_norm_prefill, (lv162_1, lv213, model_layers_18_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv215: R.Tensor((1, seq_len, 2048), dtype="float16") = lv214[1]
            rms_norm182: R.Tensor((1, seq_len, 2048), dtype="float16") = lv214[0]
            lv54 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_18_self_attn_c_attn_q_weight3, model_layers_18_self_attn_c_attn_q_scale3, rms_norm182, model_layers_18_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape360 = R.call_tir(cls.reshape4, (lv54,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape361 = R.call_tir(cls.reshape5, (reshape360,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv455 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape361), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape362 = R.call_tir(cls.reshape6, (lv455,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape363 = R.call_tir(cls.reshape7, (reshape362,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv163_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_18_self_attn_o_proj_q_weight3, model_layers_18_self_attn_o_proj_q_scale3, reshape363), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv216 = R.call_tir(cls.fuse_add_norm_prefill, (lv163_1, lv215, model_layers_18_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv217: R.Tensor((1, seq_len, 2048), dtype="float16") = lv216[1]
            rms_norm183: R.Tensor((1, seq_len, 2048), dtype="float16") = lv216[0]
            lv164_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_18_mlp_gate_up_proj_q_weight3, model_layers_18_mlp_gate_up_proj_q_scale3, rms_norm183), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv109_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv164_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv165_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_18_mlp_down_proj_q_weight3, model_layers_18_mlp_down_proj_q_scale3, lv109_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv218 = R.call_tir(cls.fuse_add_norm_prefill, (lv165_1, lv217, model_layers_19_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv219: R.Tensor((1, seq_len, 2048), dtype="float16") = lv218[1]
            rms_norm184: R.Tensor((1, seq_len, 2048), dtype="float16") = lv218[0]
            lv55 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_19_self_attn_c_attn_q_weight3, model_layers_19_self_attn_c_attn_q_scale3, rms_norm184, model_layers_19_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape364 = R.call_tir(cls.reshape4, (lv55,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape365 = R.call_tir(cls.reshape5, (reshape364,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv460 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape365), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape366 = R.call_tir(cls.reshape6, (lv460,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape367 = R.call_tir(cls.reshape7, (reshape366,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv166_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_19_self_attn_o_proj_q_weight3, model_layers_19_self_attn_o_proj_q_scale3, reshape367), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv220 = R.call_tir(cls.fuse_add_norm_prefill, (lv166_1, lv219, model_layers_19_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv221: R.Tensor((1, seq_len, 2048), dtype="float16") = lv220[1]
            rms_norm185: R.Tensor((1, seq_len, 2048), dtype="float16") = lv220[0]
            lv167_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_19_mlp_gate_up_proj_q_weight3, model_layers_19_mlp_gate_up_proj_q_scale3, rms_norm185), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv111_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv167_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv168_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_19_mlp_down_proj_q_weight3, model_layers_19_mlp_down_proj_q_scale3, lv111_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv222 = R.call_tir(cls.fuse_add_norm_prefill, (lv168_1, lv221, model_layers_20_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv223: R.Tensor((1, seq_len, 2048), dtype="float16") = lv222[1]
            rms_norm186: R.Tensor((1, seq_len, 2048), dtype="float16") = lv222[0]
            lv56 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_20_self_attn_c_attn_q_weight3, model_layers_20_self_attn_c_attn_q_scale3, rms_norm186, model_layers_20_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape368 = R.call_tir(cls.reshape4, (lv56,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape369 = R.call_tir(cls.reshape5, (reshape368,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv465 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape369), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape370 = R.call_tir(cls.reshape6, (lv465,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape371 = R.call_tir(cls.reshape7, (reshape370,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv169_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_20_self_attn_o_proj_q_weight3, model_layers_20_self_attn_o_proj_q_scale3, reshape371), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv224 = R.call_tir(cls.fuse_add_norm_prefill, (lv169_1, lv223, model_layers_20_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv225: R.Tensor((1, seq_len, 2048), dtype="float16") = lv224[1]
            rms_norm187: R.Tensor((1, seq_len, 2048), dtype="float16") = lv224[0]
            lv170_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_20_mlp_gate_up_proj_q_weight3, model_layers_20_mlp_gate_up_proj_q_scale3, rms_norm187), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv113_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv170_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv171_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_20_mlp_down_proj_q_weight3, model_layers_20_mlp_down_proj_q_scale3, lv113_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv226 = R.call_tir(cls.fuse_add_norm_prefill, (lv171_1, lv225, model_layers_21_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv227: R.Tensor((1, seq_len, 2048), dtype="float16") = lv226[1]
            rms_norm188: R.Tensor((1, seq_len, 2048), dtype="float16") = lv226[0]
            lv57 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_21_self_attn_c_attn_q_weight3, model_layers_21_self_attn_c_attn_q_scale3, rms_norm188, model_layers_21_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape372 = R.call_tir(cls.reshape4, (lv57,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape373 = R.call_tir(cls.reshape5, (reshape372,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv470 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape373), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape374 = R.call_tir(cls.reshape6, (lv470,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape375 = R.call_tir(cls.reshape7, (reshape374,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv172_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_21_self_attn_o_proj_q_weight3, model_layers_21_self_attn_o_proj_q_scale3, reshape375), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv228 = R.call_tir(cls.fuse_add_norm_prefill, (lv172_1, lv227, model_layers_21_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv229: R.Tensor((1, seq_len, 2048), dtype="float16") = lv228[1]
            rms_norm189: R.Tensor((1, seq_len, 2048), dtype="float16") = lv228[0]
            lv173_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_21_mlp_gate_up_proj_q_weight3, model_layers_21_mlp_gate_up_proj_q_scale3, rms_norm189), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv115_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv173_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv174_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_21_mlp_down_proj_q_weight3, model_layers_21_mlp_down_proj_q_scale3, lv115_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv230 = R.call_tir(cls.fuse_add_norm_prefill, (lv174_1, lv229, model_layers_22_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv231: R.Tensor((1, seq_len, 2048), dtype="float16") = lv230[1]
            rms_norm190: R.Tensor((1, seq_len, 2048), dtype="float16") = lv230[0]
            lv58 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_22_self_attn_c_attn_q_weight3, model_layers_22_self_attn_c_attn_q_scale3, rms_norm190, model_layers_22_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape376 = R.call_tir(cls.reshape4, (lv58,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape377 = R.call_tir(cls.reshape5, (reshape376,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv475 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape377), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape378 = R.call_tir(cls.reshape6, (lv475,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape379 = R.call_tir(cls.reshape7, (reshape378,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv175_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_22_self_attn_o_proj_q_weight3, model_layers_22_self_attn_o_proj_q_scale3, reshape379), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv232 = R.call_tir(cls.fuse_add_norm_prefill, (lv175_1, lv231, model_layers_22_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv233: R.Tensor((1, seq_len, 2048), dtype="float16") = lv232[1]
            rms_norm191: R.Tensor((1, seq_len, 2048), dtype="float16") = lv232[0]
            lv176_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_22_mlp_gate_up_proj_q_weight3, model_layers_22_mlp_gate_up_proj_q_scale3, rms_norm191), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv117_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv176_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv177_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_22_mlp_down_proj_q_weight3, model_layers_22_mlp_down_proj_q_scale3, lv117_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv234 = R.call_tir(cls.fuse_add_norm_prefill, (lv177_1, lv233, model_layers_23_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv235: R.Tensor((1, seq_len, 2048), dtype="float16") = lv234[1]
            rms_norm192: R.Tensor((1, seq_len, 2048), dtype="float16") = lv234[0]
            lv59 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_23_self_attn_c_attn_q_weight3, model_layers_23_self_attn_c_attn_q_scale3, rms_norm192, model_layers_23_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape380 = R.call_tir(cls.reshape4, (lv59,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape381 = R.call_tir(cls.reshape5, (reshape380,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv480 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape381), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape382 = R.call_tir(cls.reshape6, (lv480,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape383 = R.call_tir(cls.reshape7, (reshape382,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv178_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_23_self_attn_o_proj_q_weight3, model_layers_23_self_attn_o_proj_q_scale3, reshape383), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv236 = R.call_tir(cls.fuse_add_norm_prefill, (lv178_1, lv235, model_layers_23_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv237: R.Tensor((1, seq_len, 2048), dtype="float16") = lv236[1]
            rms_norm193: R.Tensor((1, seq_len, 2048), dtype="float16") = lv236[0]
            lv179_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_23_mlp_gate_up_proj_q_weight3, model_layers_23_mlp_gate_up_proj_q_scale3, rms_norm193), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv119_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv179_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv180_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_23_mlp_down_proj_q_weight3, model_layers_23_mlp_down_proj_q_scale3, lv119_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv238 = R.call_tir(cls.fuse_add_norm_prefill, (lv180_1, lv237, model_layers_24_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv239: R.Tensor((1, seq_len, 2048), dtype="float16") = lv238[1]
            rms_norm194: R.Tensor((1, seq_len, 2048), dtype="float16") = lv238[0]
            lv60 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_24_self_attn_c_attn_q_weight3, model_layers_24_self_attn_c_attn_q_scale3, rms_norm194, model_layers_24_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape384 = R.call_tir(cls.reshape4, (lv60,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape385 = R.call_tir(cls.reshape5, (reshape384,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv485 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape385), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape386 = R.call_tir(cls.reshape6, (lv485,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape387 = R.call_tir(cls.reshape7, (reshape386,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv181_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_24_self_attn_o_proj_q_weight3, model_layers_24_self_attn_o_proj_q_scale3, reshape387), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv240 = R.call_tir(cls.fuse_add_norm_prefill, (lv181_1, lv239, model_layers_24_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv241: R.Tensor((1, seq_len, 2048), dtype="float16") = lv240[1]
            rms_norm195: R.Tensor((1, seq_len, 2048), dtype="float16") = lv240[0]
            lv182_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_24_mlp_gate_up_proj_q_weight3, model_layers_24_mlp_gate_up_proj_q_scale3, rms_norm195), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv121_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv182_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv183_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_24_mlp_down_proj_q_weight3, model_layers_24_mlp_down_proj_q_scale3, lv121_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv242 = R.call_tir(cls.fuse_add_norm_prefill, (lv183_1, lv241, model_layers_25_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv243: R.Tensor((1, seq_len, 2048), dtype="float16") = lv242[1]
            rms_norm196: R.Tensor((1, seq_len, 2048), dtype="float16") = lv242[0]
            lv61 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_25_self_attn_c_attn_q_weight3, model_layers_25_self_attn_c_attn_q_scale3, rms_norm196, model_layers_25_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape388 = R.call_tir(cls.reshape4, (lv61,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape389 = R.call_tir(cls.reshape5, (reshape388,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv490 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape389), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape390 = R.call_tir(cls.reshape6, (lv490,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape391 = R.call_tir(cls.reshape7, (reshape390,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv184_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_25_self_attn_o_proj_q_weight3, model_layers_25_self_attn_o_proj_q_scale3, reshape391), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv244 = R.call_tir(cls.fuse_add_norm_prefill, (lv184_1, lv243, model_layers_25_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv245: R.Tensor((1, seq_len, 2048), dtype="float16") = lv244[1]
            rms_norm197: R.Tensor((1, seq_len, 2048), dtype="float16") = lv244[0]
            lv185_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_25_mlp_gate_up_proj_q_weight3, model_layers_25_mlp_gate_up_proj_q_scale3, rms_norm197), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv123_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv185_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv186_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_25_mlp_down_proj_q_weight3, model_layers_25_mlp_down_proj_q_scale3, lv123_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv246 = R.call_tir(cls.fuse_add_norm_prefill, (lv186_1, lv245, model_layers_26_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv247: R.Tensor((1, seq_len, 2048), dtype="float16") = lv246[1]
            rms_norm198: R.Tensor((1, seq_len, 2048), dtype="float16") = lv246[0]
            lv62 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_26_self_attn_c_attn_q_weight3, model_layers_26_self_attn_c_attn_q_scale3, rms_norm198, model_layers_26_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape392 = R.call_tir(cls.reshape4, (lv62,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape393 = R.call_tir(cls.reshape5, (reshape392,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv495 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape393), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape394 = R.call_tir(cls.reshape6, (lv495,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape395 = R.call_tir(cls.reshape7, (reshape394,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv187_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_26_self_attn_o_proj_q_weight3, model_layers_26_self_attn_o_proj_q_scale3, reshape395), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv248 = R.call_tir(cls.fuse_add_norm_prefill, (lv187_1, lv247, model_layers_26_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv249: R.Tensor((1, seq_len, 2048), dtype="float16") = lv248[1]
            rms_norm199: R.Tensor((1, seq_len, 2048), dtype="float16") = lv248[0]
            lv188_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_26_mlp_gate_up_proj_q_weight3, model_layers_26_mlp_gate_up_proj_q_scale3, rms_norm199), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv125_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv188_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv189_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_26_mlp_down_proj_q_weight3, model_layers_26_mlp_down_proj_q_scale3, lv125_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv250 = R.call_tir(cls.fuse_add_norm_prefill, (lv189_1, lv249, model_layers_27_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv251: R.Tensor((1, seq_len, 2048), dtype="float16") = lv250[1]
            rms_norm200: R.Tensor((1, seq_len, 2048), dtype="float16") = lv250[0]
            lv63 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_27_self_attn_c_attn_q_weight3, model_layers_27_self_attn_c_attn_q_scale3, rms_norm200, model_layers_27_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape396 = R.call_tir(cls.reshape4, (lv63,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape397 = R.call_tir(cls.reshape5, (reshape396,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv500 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape397), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape398 = R.call_tir(cls.reshape6, (lv500,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape399 = R.call_tir(cls.reshape7, (reshape398,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv190_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_27_self_attn_o_proj_q_weight3, model_layers_27_self_attn_o_proj_q_scale3, reshape399), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv252 = R.call_tir(cls.fuse_add_norm_prefill, (lv190_1, lv251, model_layers_27_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv253: R.Tensor((1, seq_len, 2048), dtype="float16") = lv252[1]
            rms_norm201: R.Tensor((1, seq_len, 2048), dtype="float16") = lv252[0]
            lv191_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_27_mlp_gate_up_proj_q_weight3, model_layers_27_mlp_gate_up_proj_q_scale3, rms_norm201), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv127_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv191_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv192_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_27_mlp_down_proj_q_weight3, model_layers_27_mlp_down_proj_q_scale3, lv127_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv254 = R.call_tir(cls.fuse_add_norm_prefill, (lv192_1, lv253, model_layers_28_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv255: R.Tensor((1, seq_len, 2048), dtype="float16") = lv254[1]
            rms_norm202: R.Tensor((1, seq_len, 2048), dtype="float16") = lv254[0]
            lv64 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_28_self_attn_c_attn_q_weight3, model_layers_28_self_attn_c_attn_q_scale3, rms_norm202, model_layers_28_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape400 = R.call_tir(cls.reshape4, (lv64,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape401 = R.call_tir(cls.reshape5, (reshape400,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv505 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape401), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape402 = R.call_tir(cls.reshape6, (lv505,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape403 = R.call_tir(cls.reshape7, (reshape402,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv193_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_28_self_attn_o_proj_q_weight3, model_layers_28_self_attn_o_proj_q_scale3, reshape403), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv256 = R.call_tir(cls.fuse_add_norm_prefill, (lv193_1, lv255, model_layers_28_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv257: R.Tensor((1, seq_len, 2048), dtype="float16") = lv256[1]
            rms_norm203: R.Tensor((1, seq_len, 2048), dtype="float16") = lv256[0]
            lv194_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_28_mlp_gate_up_proj_q_weight3, model_layers_28_mlp_gate_up_proj_q_scale3, rms_norm203), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv129_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv194_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv195_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_28_mlp_down_proj_q_weight3, model_layers_28_mlp_down_proj_q_scale3, lv129_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv258 = R.call_tir(cls.fuse_add_norm_prefill, (lv195_1, lv257, model_layers_29_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv259: R.Tensor((1, seq_len, 2048), dtype="float16") = lv258[1]
            rms_norm204: R.Tensor((1, seq_len, 2048), dtype="float16") = lv258[0]
            lv65 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_29_self_attn_c_attn_q_weight3, model_layers_29_self_attn_c_attn_q_scale3, rms_norm204, model_layers_29_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape404 = R.call_tir(cls.reshape4, (lv65,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape405 = R.call_tir(cls.reshape5, (reshape404,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv510 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape405), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape406 = R.call_tir(cls.reshape6, (lv510,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape407 = R.call_tir(cls.reshape7, (reshape406,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv196_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_29_self_attn_o_proj_q_weight3, model_layers_29_self_attn_o_proj_q_scale3, reshape407), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv260 = R.call_tir(cls.fuse_add_norm_prefill, (lv196_1, lv259, model_layers_29_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv261: R.Tensor((1, seq_len, 2048), dtype="float16") = lv260[1]
            rms_norm205: R.Tensor((1, seq_len, 2048), dtype="float16") = lv260[0]
            lv197_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_29_mlp_gate_up_proj_q_weight3, model_layers_29_mlp_gate_up_proj_q_scale3, rms_norm205), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv131_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv197_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv198_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_29_mlp_down_proj_q_weight3, model_layers_29_mlp_down_proj_q_scale3, lv131_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv262 = R.call_tir(cls.fuse_add_norm_prefill, (lv198_1, lv261, model_layers_30_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv263: R.Tensor((1, seq_len, 2048), dtype="float16") = lv262[1]
            rms_norm206: R.Tensor((1, seq_len, 2048), dtype="float16") = lv262[0]
            lv66 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_30_self_attn_c_attn_q_weight3, model_layers_30_self_attn_c_attn_q_scale3, rms_norm206, model_layers_30_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape408 = R.call_tir(cls.reshape4, (lv66,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape409 = R.call_tir(cls.reshape5, (reshape408,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv515 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape409), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape410 = R.call_tir(cls.reshape6, (lv515,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape411 = R.call_tir(cls.reshape7, (reshape410,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv199_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_30_self_attn_o_proj_q_weight3, model_layers_30_self_attn_o_proj_q_scale3, reshape411), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv264 = R.call_tir(cls.fuse_add_norm_prefill, (lv199_1, lv263, model_layers_30_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv265: R.Tensor((1, seq_len, 2048), dtype="float16") = lv264[1]
            rms_norm207: R.Tensor((1, seq_len, 2048), dtype="float16") = lv264[0]
            lv200_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_30_mlp_gate_up_proj_q_weight3, model_layers_30_mlp_gate_up_proj_q_scale3, rms_norm207), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv133_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv200_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv201_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_30_mlp_down_proj_q_weight3, model_layers_30_mlp_down_proj_q_scale3, lv133_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv266 = R.call_tir(cls.fuse_add_norm_prefill, (lv201_1, lv265, model_layers_31_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv267: R.Tensor((1, seq_len, 2048), dtype="float16") = lv266[1]
            rms_norm208: R.Tensor((1, seq_len, 2048), dtype="float16") = lv266[0]
            lv67 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_31_self_attn_c_attn_q_weight3, model_layers_31_self_attn_c_attn_q_scale3, rms_norm208, model_layers_31_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape412 = R.call_tir(cls.reshape4, (lv67,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape413 = R.call_tir(cls.reshape5, (reshape412,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv520 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape413), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape414 = R.call_tir(cls.reshape6, (lv520,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape415 = R.call_tir(cls.reshape7, (reshape414,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv202_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_31_self_attn_o_proj_q_weight3, model_layers_31_self_attn_o_proj_q_scale3, reshape415), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv268 = R.call_tir(cls.fuse_add_norm_prefill, (lv202_1, lv267, model_layers_31_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv269: R.Tensor((1, seq_len, 2048), dtype="float16") = lv268[1]
            rms_norm209: R.Tensor((1, seq_len, 2048), dtype="float16") = lv268[0]
            lv203_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_31_mlp_gate_up_proj_q_weight3, model_layers_31_mlp_gate_up_proj_q_scale3, rms_norm209), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv135_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv203_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv204_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_31_mlp_down_proj_q_weight3, model_layers_31_mlp_down_proj_q_scale3, lv135_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv270 = R.call_tir(cls.fuse_add_norm_prefill, (lv204_1, lv269, model_layers_32_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv271: R.Tensor((1, seq_len, 2048), dtype="float16") = lv270[1]
            rms_norm210: R.Tensor((1, seq_len, 2048), dtype="float16") = lv270[0]
            lv68 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_32_self_attn_c_attn_q_weight3, model_layers_32_self_attn_c_attn_q_scale3, rms_norm210, model_layers_32_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape416 = R.call_tir(cls.reshape4, (lv68,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape417 = R.call_tir(cls.reshape5, (reshape416,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv525 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape417), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape418 = R.call_tir(cls.reshape6, (lv525,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape419 = R.call_tir(cls.reshape7, (reshape418,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv205_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_32_self_attn_o_proj_q_weight3, model_layers_32_self_attn_o_proj_q_scale3, reshape419), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv272 = R.call_tir(cls.fuse_add_norm_prefill, (lv205_1, lv271, model_layers_32_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv273: R.Tensor((1, seq_len, 2048), dtype="float16") = lv272[1]
            rms_norm211: R.Tensor((1, seq_len, 2048), dtype="float16") = lv272[0]
            lv206_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_32_mlp_gate_up_proj_q_weight3, model_layers_32_mlp_gate_up_proj_q_scale3, rms_norm211), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv137_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv206_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv207_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_32_mlp_down_proj_q_weight3, model_layers_32_mlp_down_proj_q_scale3, lv137_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv274 = R.call_tir(cls.fuse_add_norm_prefill, (lv207_1, lv273, model_layers_33_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv275: R.Tensor((1, seq_len, 2048), dtype="float16") = lv274[1]
            rms_norm212: R.Tensor((1, seq_len, 2048), dtype="float16") = lv274[0]
            lv69 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_33_self_attn_c_attn_q_weight3, model_layers_33_self_attn_c_attn_q_scale3, rms_norm212, model_layers_33_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape420 = R.call_tir(cls.reshape4, (lv69,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape421 = R.call_tir(cls.reshape5, (reshape420,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv530 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape421), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape422 = R.call_tir(cls.reshape6, (lv530,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape423 = R.call_tir(cls.reshape7, (reshape422,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv208_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_33_self_attn_o_proj_q_weight3, model_layers_33_self_attn_o_proj_q_scale3, reshape423), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv276 = R.call_tir(cls.fuse_add_norm_prefill, (lv208_1, lv275, model_layers_33_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv277: R.Tensor((1, seq_len, 2048), dtype="float16") = lv276[1]
            rms_norm213: R.Tensor((1, seq_len, 2048), dtype="float16") = lv276[0]
            lv209_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_33_mlp_gate_up_proj_q_weight3, model_layers_33_mlp_gate_up_proj_q_scale3, rms_norm213), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv139_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv209_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv210_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_33_mlp_down_proj_q_weight3, model_layers_33_mlp_down_proj_q_scale3, lv139_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv278 = R.call_tir(cls.fuse_add_norm_prefill, (lv210_1, lv277, model_layers_34_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv279: R.Tensor((1, seq_len, 2048), dtype="float16") = lv278[1]
            rms_norm214: R.Tensor((1, seq_len, 2048), dtype="float16") = lv278[0]
            lv70 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_34_self_attn_c_attn_q_weight3, model_layers_34_self_attn_c_attn_q_scale3, rms_norm214, model_layers_34_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape424 = R.call_tir(cls.reshape4, (lv70,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape425 = R.call_tir(cls.reshape5, (reshape424,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv535 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape425), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape426 = R.call_tir(cls.reshape6, (lv535,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape427 = R.call_tir(cls.reshape7, (reshape426,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv211_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_34_self_attn_o_proj_q_weight3, model_layers_34_self_attn_o_proj_q_scale3, reshape427), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv280 = R.call_tir(cls.fuse_add_norm_prefill, (lv211_1, lv279, model_layers_34_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv281: R.Tensor((1, seq_len, 2048), dtype="float16") = lv280[1]
            rms_norm215: R.Tensor((1, seq_len, 2048), dtype="float16") = lv280[0]
            lv212_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_34_mlp_gate_up_proj_q_weight3, model_layers_34_mlp_gate_up_proj_q_scale3, rms_norm215), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv141_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv212_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv213_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_34_mlp_down_proj_q_weight3, model_layers_34_mlp_down_proj_q_scale3, lv141_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv282 = R.call_tir(cls.fuse_add_norm_prefill, (lv213_1, lv281, model_layers_35_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv283: R.Tensor((1, seq_len, 2048), dtype="float16") = lv282[1]
            rms_norm216: R.Tensor((1, seq_len, 2048), dtype="float16") = lv282[0]
            lv71 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_35_self_attn_c_attn_q_weight3, model_layers_35_self_attn_c_attn_q_scale3, rms_norm216, model_layers_35_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape428 = R.call_tir(cls.reshape4, (lv71,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape429 = R.call_tir(cls.reshape5, (reshape428,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv540 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape429), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape430 = R.call_tir(cls.reshape6, (lv540,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape431 = R.call_tir(cls.reshape7, (reshape430,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv214_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_35_self_attn_o_proj_q_weight3, model_layers_35_self_attn_o_proj_q_scale3, reshape431), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv284 = R.call_tir(cls.fuse_add_norm_prefill, (lv214_1, lv283, model_layers_35_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv285: R.Tensor((1, seq_len, 2048), dtype="float16") = lv284[1]
            rms_norm217: R.Tensor((1, seq_len, 2048), dtype="float16") = lv284[0]
            lv215_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_35_mlp_gate_up_proj_q_weight3, model_layers_35_mlp_gate_up_proj_q_scale3, rms_norm217), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv143_1 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv215_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv216_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_35_mlp_down_proj_q_weight3, model_layers_35_mlp_down_proj_q_scale3, lv143_1), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv286 = R.call_tir(cls.fuse_add_norm_prefill, (lv216_1, lv285, model_norm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            rms_norm218: R.Tensor((1, seq_len, 2048), dtype="float16") = lv286[0]
            take1 = R.call_tir(cls.take, (rms_norm218, logit_positions), out_sinfo=R.Tensor((1, batch_size, 2048), dtype="float16"))
            lv217_1 = R.call_tir(cls.fused_dequantize_NT_matmul9, (model_embed_tokens_q_weight3, model_embed_tokens_q_scale3, take1), out_sinfo=R.Tensor((1, batch_size, 151936), dtype="float32"))
            gv3: R.Tuple(R.Tensor((1, batch_size, 151936), dtype="float32"), R.Object) = lv217_1, paged_kv_cache
            R.output(gv3)
        return gv3

    @R.function
    def batch_verify(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), 
R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 
64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), 
R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), 
dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), 
R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), 
dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, "seq_len", 151936), dtype="float32"), R.Object):
        seq_len = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size", "seq_len"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight5: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale5: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight5: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale5: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias5: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight5: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale5: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight5: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale5: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm292 = R.call_tir(cls.rms_norm1, (input_embeds, model_layers_0_input_layernorm_weight5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv72 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_0_self_attn_c_attn_q_weight5, model_layers_0_self_attn_c_attn_q_scale5, rms_norm292, model_layers_0_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape576 = R.call_tir(cls.reshape4, (lv72,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape577 = R.call_tir(cls.reshape5, (reshape576,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv727 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape577), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape578 = R.call_tir(cls.reshape6, (lv727,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape579 = R.call_tir(cls.reshape7, (reshape578,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv218 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_0_self_attn_o_proj_q_weight5, model_layers_0_self_attn_o_proj_q_scale5, reshape579), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv288 = R.call_tir(cls.fuse_add_norm_prefill, (lv218, input_embeds, model_layers_0_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv289: R.Tensor((1, seq_len, 2048), dtype="float16") = lv288[1]
            rms_norm293: R.Tensor((1, seq_len, 2048), dtype="float16") = lv288[0]
            lv219 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_0_mlp_gate_up_proj_q_weight5, model_layers_0_mlp_gate_up_proj_q_scale5, rms_norm293), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv145 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv219,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv220 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_0_mlp_down_proj_q_weight5, model_layers_0_mlp_down_proj_q_scale5, lv145), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv290 = R.call_tir(cls.fuse_add_norm_prefill, (lv220, lv289, model_layers_1_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv291: R.Tensor((1, seq_len, 2048), dtype="float16") = lv290[1]
            rms_norm294: R.Tensor((1, seq_len, 2048), dtype="float16") = lv290[0]
            lv73 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_1_self_attn_c_attn_q_weight5, model_layers_1_self_attn_c_attn_q_scale5, rms_norm294, model_layers_1_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape580 = R.call_tir(cls.reshape4, (lv73,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape581 = R.call_tir(cls.reshape5, (reshape580,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv732 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape581), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape582 = R.call_tir(cls.reshape6, (lv732,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape583 = R.call_tir(cls.reshape7, (reshape582,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv221 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_1_self_attn_o_proj_q_weight5, model_layers_1_self_attn_o_proj_q_scale5, reshape583), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv292 = R.call_tir(cls.fuse_add_norm_prefill, (lv221, lv291, model_layers_1_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv293: R.Tensor((1, seq_len, 2048), dtype="float16") = lv292[1]
            rms_norm295: R.Tensor((1, seq_len, 2048), dtype="float16") = lv292[0]
            lv222 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_1_mlp_gate_up_proj_q_weight5, model_layers_1_mlp_gate_up_proj_q_scale5, rms_norm295), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv147 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv222,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv223 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_1_mlp_down_proj_q_weight5, model_layers_1_mlp_down_proj_q_scale5, lv147), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv294 = R.call_tir(cls.fuse_add_norm_prefill, (lv223, lv293, model_layers_2_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv295: R.Tensor((1, seq_len, 2048), dtype="float16") = lv294[1]
            rms_norm296: R.Tensor((1, seq_len, 2048), dtype="float16") = lv294[0]
            lv74 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_2_self_attn_c_attn_q_weight5, model_layers_2_self_attn_c_attn_q_scale5, rms_norm296, model_layers_2_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape584 = R.call_tir(cls.reshape4, (lv74,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape585 = R.call_tir(cls.reshape5, (reshape584,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv737 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape585), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape586 = R.call_tir(cls.reshape6, (lv737,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape587 = R.call_tir(cls.reshape7, (reshape586,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv224 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_2_self_attn_o_proj_q_weight5, model_layers_2_self_attn_o_proj_q_scale5, reshape587), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv296 = R.call_tir(cls.fuse_add_norm_prefill, (lv224, lv295, model_layers_2_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv297: R.Tensor((1, seq_len, 2048), dtype="float16") = lv296[1]
            rms_norm297: R.Tensor((1, seq_len, 2048), dtype="float16") = lv296[0]
            lv225 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_2_mlp_gate_up_proj_q_weight5, model_layers_2_mlp_gate_up_proj_q_scale5, rms_norm297), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv149 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv225,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv226 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_2_mlp_down_proj_q_weight5, model_layers_2_mlp_down_proj_q_scale5, lv149), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv298 = R.call_tir(cls.fuse_add_norm_prefill, (lv226, lv297, model_layers_3_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv299: R.Tensor((1, seq_len, 2048), dtype="float16") = lv298[1]
            rms_norm298: R.Tensor((1, seq_len, 2048), dtype="float16") = lv298[0]
            lv75 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_3_self_attn_c_attn_q_weight5, model_layers_3_self_attn_c_attn_q_scale5, rms_norm298, model_layers_3_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape588 = R.call_tir(cls.reshape4, (lv75,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape589 = R.call_tir(cls.reshape5, (reshape588,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv742 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape589), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape590 = R.call_tir(cls.reshape6, (lv742,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape591 = R.call_tir(cls.reshape7, (reshape590,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv227 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_3_self_attn_o_proj_q_weight5, model_layers_3_self_attn_o_proj_q_scale5, reshape591), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv300 = R.call_tir(cls.fuse_add_norm_prefill, (lv227, lv299, model_layers_3_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv301: R.Tensor((1, seq_len, 2048), dtype="float16") = lv300[1]
            rms_norm299: R.Tensor((1, seq_len, 2048), dtype="float16") = lv300[0]
            lv228 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_3_mlp_gate_up_proj_q_weight5, model_layers_3_mlp_gate_up_proj_q_scale5, rms_norm299), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv151 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv228,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv229 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_3_mlp_down_proj_q_weight5, model_layers_3_mlp_down_proj_q_scale5, lv151), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv302 = R.call_tir(cls.fuse_add_norm_prefill, (lv229, lv301, model_layers_4_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv303: R.Tensor((1, seq_len, 2048), dtype="float16") = lv302[1]
            rms_norm300: R.Tensor((1, seq_len, 2048), dtype="float16") = lv302[0]
            lv76 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_4_self_attn_c_attn_q_weight5, model_layers_4_self_attn_c_attn_q_scale5, rms_norm300, model_layers_4_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape592 = R.call_tir(cls.reshape4, (lv76,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape593 = R.call_tir(cls.reshape5, (reshape592,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv747 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape593), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape594 = R.call_tir(cls.reshape6, (lv747,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape595 = R.call_tir(cls.reshape7, (reshape594,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv230 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_4_self_attn_o_proj_q_weight5, model_layers_4_self_attn_o_proj_q_scale5, reshape595), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv304 = R.call_tir(cls.fuse_add_norm_prefill, (lv230, lv303, model_layers_4_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv305: R.Tensor((1, seq_len, 2048), dtype="float16") = lv304[1]
            rms_norm301: R.Tensor((1, seq_len, 2048), dtype="float16") = lv304[0]
            lv231 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_4_mlp_gate_up_proj_q_weight5, model_layers_4_mlp_gate_up_proj_q_scale5, rms_norm301), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv153 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv231,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv232 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_4_mlp_down_proj_q_weight5, model_layers_4_mlp_down_proj_q_scale5, lv153), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv306 = R.call_tir(cls.fuse_add_norm_prefill, (lv232, lv305, model_layers_5_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv307: R.Tensor((1, seq_len, 2048), dtype="float16") = lv306[1]
            rms_norm302: R.Tensor((1, seq_len, 2048), dtype="float16") = lv306[0]
            lv77 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_5_self_attn_c_attn_q_weight5, model_layers_5_self_attn_c_attn_q_scale5, rms_norm302, model_layers_5_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape596 = R.call_tir(cls.reshape4, (lv77,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape597 = R.call_tir(cls.reshape5, (reshape596,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv752 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape597), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape598 = R.call_tir(cls.reshape6, (lv752,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape599 = R.call_tir(cls.reshape7, (reshape598,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv233 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_5_self_attn_o_proj_q_weight5, model_layers_5_self_attn_o_proj_q_scale5, reshape599), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv308 = R.call_tir(cls.fuse_add_norm_prefill, (lv233, lv307, model_layers_5_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv309: R.Tensor((1, seq_len, 2048), dtype="float16") = lv308[1]
            rms_norm303: R.Tensor((1, seq_len, 2048), dtype="float16") = lv308[0]
            lv234 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_5_mlp_gate_up_proj_q_weight5, model_layers_5_mlp_gate_up_proj_q_scale5, rms_norm303), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv155 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv234,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv235 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_5_mlp_down_proj_q_weight5, model_layers_5_mlp_down_proj_q_scale5, lv155), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv310 = R.call_tir(cls.fuse_add_norm_prefill, (lv235, lv309, model_layers_6_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv311: R.Tensor((1, seq_len, 2048), dtype="float16") = lv310[1]
            rms_norm304: R.Tensor((1, seq_len, 2048), dtype="float16") = lv310[0]
            lv78 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_6_self_attn_c_attn_q_weight5, model_layers_6_self_attn_c_attn_q_scale5, rms_norm304, model_layers_6_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape600 = R.call_tir(cls.reshape4, (lv78,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape601 = R.call_tir(cls.reshape5, (reshape600,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv757 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape601), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape602 = R.call_tir(cls.reshape6, (lv757,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape603 = R.call_tir(cls.reshape7, (reshape602,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv236 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_6_self_attn_o_proj_q_weight5, model_layers_6_self_attn_o_proj_q_scale5, reshape603), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv312 = R.call_tir(cls.fuse_add_norm_prefill, (lv236, lv311, model_layers_6_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv313: R.Tensor((1, seq_len, 2048), dtype="float16") = lv312[1]
            rms_norm305: R.Tensor((1, seq_len, 2048), dtype="float16") = lv312[0]
            lv237 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_6_mlp_gate_up_proj_q_weight5, model_layers_6_mlp_gate_up_proj_q_scale5, rms_norm305), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv157 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv237,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv238 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_6_mlp_down_proj_q_weight5, model_layers_6_mlp_down_proj_q_scale5, lv157), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv314 = R.call_tir(cls.fuse_add_norm_prefill, (lv238, lv313, model_layers_7_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv315: R.Tensor((1, seq_len, 2048), dtype="float16") = lv314[1]
            rms_norm306: R.Tensor((1, seq_len, 2048), dtype="float16") = lv314[0]
            lv79 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_7_self_attn_c_attn_q_weight5, model_layers_7_self_attn_c_attn_q_scale5, rms_norm306, model_layers_7_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape604 = R.call_tir(cls.reshape4, (lv79,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape605 = R.call_tir(cls.reshape5, (reshape604,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv762 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape605), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape606 = R.call_tir(cls.reshape6, (lv762,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape607 = R.call_tir(cls.reshape7, (reshape606,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv239 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_7_self_attn_o_proj_q_weight5, model_layers_7_self_attn_o_proj_q_scale5, reshape607), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv316 = R.call_tir(cls.fuse_add_norm_prefill, (lv239, lv315, model_layers_7_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv317: R.Tensor((1, seq_len, 2048), dtype="float16") = lv316[1]
            rms_norm307: R.Tensor((1, seq_len, 2048), dtype="float16") = lv316[0]
            lv240 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_7_mlp_gate_up_proj_q_weight5, model_layers_7_mlp_gate_up_proj_q_scale5, rms_norm307), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv159 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv240,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv241 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_7_mlp_down_proj_q_weight5, model_layers_7_mlp_down_proj_q_scale5, lv159), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv318 = R.call_tir(cls.fuse_add_norm_prefill, (lv241, lv317, model_layers_8_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv319: R.Tensor((1, seq_len, 2048), dtype="float16") = lv318[1]
            rms_norm308: R.Tensor((1, seq_len, 2048), dtype="float16") = lv318[0]
            lv80 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_8_self_attn_c_attn_q_weight5, model_layers_8_self_attn_c_attn_q_scale5, rms_norm308, model_layers_8_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape608 = R.call_tir(cls.reshape4, (lv80,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape609 = R.call_tir(cls.reshape5, (reshape608,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv767 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape609), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape610 = R.call_tir(cls.reshape6, (lv767,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape611 = R.call_tir(cls.reshape7, (reshape610,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv242 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_8_self_attn_o_proj_q_weight5, model_layers_8_self_attn_o_proj_q_scale5, reshape611), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv320 = R.call_tir(cls.fuse_add_norm_prefill, (lv242, lv319, model_layers_8_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv321: R.Tensor((1, seq_len, 2048), dtype="float16") = lv320[1]
            rms_norm309: R.Tensor((1, seq_len, 2048), dtype="float16") = lv320[0]
            lv243 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_8_mlp_gate_up_proj_q_weight5, model_layers_8_mlp_gate_up_proj_q_scale5, rms_norm309), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv161 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv243,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv244 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_8_mlp_down_proj_q_weight5, model_layers_8_mlp_down_proj_q_scale5, lv161), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv322 = R.call_tir(cls.fuse_add_norm_prefill, (lv244, lv321, model_layers_9_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv323: R.Tensor((1, seq_len, 2048), dtype="float16") = lv322[1]
            rms_norm310: R.Tensor((1, seq_len, 2048), dtype="float16") = lv322[0]
            lv81 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_9_self_attn_c_attn_q_weight5, model_layers_9_self_attn_c_attn_q_scale5, rms_norm310, model_layers_9_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape612 = R.call_tir(cls.reshape4, (lv81,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape613 = R.call_tir(cls.reshape5, (reshape612,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv772 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape613), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape614 = R.call_tir(cls.reshape6, (lv772,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape615 = R.call_tir(cls.reshape7, (reshape614,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv245 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_9_self_attn_o_proj_q_weight5, model_layers_9_self_attn_o_proj_q_scale5, reshape615), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv324 = R.call_tir(cls.fuse_add_norm_prefill, (lv245, lv323, model_layers_9_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv325: R.Tensor((1, seq_len, 2048), dtype="float16") = lv324[1]
            rms_norm311: R.Tensor((1, seq_len, 2048), dtype="float16") = lv324[0]
            lv246 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_9_mlp_gate_up_proj_q_weight5, model_layers_9_mlp_gate_up_proj_q_scale5, rms_norm311), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv163 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv246,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv247 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_9_mlp_down_proj_q_weight5, model_layers_9_mlp_down_proj_q_scale5, lv163), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv326 = R.call_tir(cls.fuse_add_norm_prefill, (lv247, lv325, model_layers_10_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv327: R.Tensor((1, seq_len, 2048), dtype="float16") = lv326[1]
            rms_norm312: R.Tensor((1, seq_len, 2048), dtype="float16") = lv326[0]
            lv82 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_10_self_attn_c_attn_q_weight5, model_layers_10_self_attn_c_attn_q_scale5, rms_norm312, model_layers_10_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape616 = R.call_tir(cls.reshape4, (lv82,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape617 = R.call_tir(cls.reshape5, (reshape616,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv777 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape617), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape618 = R.call_tir(cls.reshape6, (lv777,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape619 = R.call_tir(cls.reshape7, (reshape618,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv248 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_10_self_attn_o_proj_q_weight5, model_layers_10_self_attn_o_proj_q_scale5, reshape619), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv328 = R.call_tir(cls.fuse_add_norm_prefill, (lv248, lv327, model_layers_10_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv329: R.Tensor((1, seq_len, 2048), dtype="float16") = lv328[1]
            rms_norm313: R.Tensor((1, seq_len, 2048), dtype="float16") = lv328[0]
            lv249 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_10_mlp_gate_up_proj_q_weight5, model_layers_10_mlp_gate_up_proj_q_scale5, rms_norm313), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv165 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv249,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv250 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_10_mlp_down_proj_q_weight5, model_layers_10_mlp_down_proj_q_scale5, lv165), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv330 = R.call_tir(cls.fuse_add_norm_prefill, (lv250, lv329, model_layers_11_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv331: R.Tensor((1, seq_len, 2048), dtype="float16") = lv330[1]
            rms_norm314: R.Tensor((1, seq_len, 2048), dtype="float16") = lv330[0]
            lv83 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_11_self_attn_c_attn_q_weight5, model_layers_11_self_attn_c_attn_q_scale5, rms_norm314, model_layers_11_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape620 = R.call_tir(cls.reshape4, (lv83,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape621 = R.call_tir(cls.reshape5, (reshape620,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv782 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape621), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape622 = R.call_tir(cls.reshape6, (lv782,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape623 = R.call_tir(cls.reshape7, (reshape622,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv251 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_11_self_attn_o_proj_q_weight5, model_layers_11_self_attn_o_proj_q_scale5, reshape623), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv332 = R.call_tir(cls.fuse_add_norm_prefill, (lv251, lv331, model_layers_11_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv333: R.Tensor((1, seq_len, 2048), dtype="float16") = lv332[1]
            rms_norm315: R.Tensor((1, seq_len, 2048), dtype="float16") = lv332[0]
            lv252 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_11_mlp_gate_up_proj_q_weight5, model_layers_11_mlp_gate_up_proj_q_scale5, rms_norm315), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv167 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv252,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv253 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_11_mlp_down_proj_q_weight5, model_layers_11_mlp_down_proj_q_scale5, lv167), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv334 = R.call_tir(cls.fuse_add_norm_prefill, (lv253, lv333, model_layers_12_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv335: R.Tensor((1, seq_len, 2048), dtype="float16") = lv334[1]
            rms_norm316: R.Tensor((1, seq_len, 2048), dtype="float16") = lv334[0]
            lv84 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_12_self_attn_c_attn_q_weight5, model_layers_12_self_attn_c_attn_q_scale5, rms_norm316, model_layers_12_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape624 = R.call_tir(cls.reshape4, (lv84,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape625 = R.call_tir(cls.reshape5, (reshape624,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv787 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape625), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape626 = R.call_tir(cls.reshape6, (lv787,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape627 = R.call_tir(cls.reshape7, (reshape626,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv254 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_12_self_attn_o_proj_q_weight5, model_layers_12_self_attn_o_proj_q_scale5, reshape627), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv336 = R.call_tir(cls.fuse_add_norm_prefill, (lv254, lv335, model_layers_12_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv337: R.Tensor((1, seq_len, 2048), dtype="float16") = lv336[1]
            rms_norm317: R.Tensor((1, seq_len, 2048), dtype="float16") = lv336[0]
            lv255 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_12_mlp_gate_up_proj_q_weight5, model_layers_12_mlp_gate_up_proj_q_scale5, rms_norm317), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv169 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv255,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv256 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_12_mlp_down_proj_q_weight5, model_layers_12_mlp_down_proj_q_scale5, lv169), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv338 = R.call_tir(cls.fuse_add_norm_prefill, (lv256, lv337, model_layers_13_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv339: R.Tensor((1, seq_len, 2048), dtype="float16") = lv338[1]
            rms_norm318: R.Tensor((1, seq_len, 2048), dtype="float16") = lv338[0]
            lv85 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_13_self_attn_c_attn_q_weight5, model_layers_13_self_attn_c_attn_q_scale5, rms_norm318, model_layers_13_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape628 = R.call_tir(cls.reshape4, (lv85,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape629 = R.call_tir(cls.reshape5, (reshape628,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv792 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape629), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape630 = R.call_tir(cls.reshape6, (lv792,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape631 = R.call_tir(cls.reshape7, (reshape630,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv257 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_13_self_attn_o_proj_q_weight5, model_layers_13_self_attn_o_proj_q_scale5, reshape631), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv340 = R.call_tir(cls.fuse_add_norm_prefill, (lv257, lv339, model_layers_13_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv341: R.Tensor((1, seq_len, 2048), dtype="float16") = lv340[1]
            rms_norm319: R.Tensor((1, seq_len, 2048), dtype="float16") = lv340[0]
            lv258 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_13_mlp_gate_up_proj_q_weight5, model_layers_13_mlp_gate_up_proj_q_scale5, rms_norm319), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv171 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv258,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv259 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_13_mlp_down_proj_q_weight5, model_layers_13_mlp_down_proj_q_scale5, lv171), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv342 = R.call_tir(cls.fuse_add_norm_prefill, (lv259, lv341, model_layers_14_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv343: R.Tensor((1, seq_len, 2048), dtype="float16") = lv342[1]
            rms_norm320: R.Tensor((1, seq_len, 2048), dtype="float16") = lv342[0]
            lv86 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_14_self_attn_c_attn_q_weight5, model_layers_14_self_attn_c_attn_q_scale5, rms_norm320, model_layers_14_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape632 = R.call_tir(cls.reshape4, (lv86,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape633 = R.call_tir(cls.reshape5, (reshape632,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv797 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape633), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape634 = R.call_tir(cls.reshape6, (lv797,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape635 = R.call_tir(cls.reshape7, (reshape634,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv260 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_14_self_attn_o_proj_q_weight5, model_layers_14_self_attn_o_proj_q_scale5, reshape635), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv344 = R.call_tir(cls.fuse_add_norm_prefill, (lv260, lv343, model_layers_14_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv345: R.Tensor((1, seq_len, 2048), dtype="float16") = lv344[1]
            rms_norm321: R.Tensor((1, seq_len, 2048), dtype="float16") = lv344[0]
            lv261 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_14_mlp_gate_up_proj_q_weight5, model_layers_14_mlp_gate_up_proj_q_scale5, rms_norm321), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv173 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv261,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv262 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_14_mlp_down_proj_q_weight5, model_layers_14_mlp_down_proj_q_scale5, lv173), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv346 = R.call_tir(cls.fuse_add_norm_prefill, (lv262, lv345, model_layers_15_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv347: R.Tensor((1, seq_len, 2048), dtype="float16") = lv346[1]
            rms_norm322: R.Tensor((1, seq_len, 2048), dtype="float16") = lv346[0]
            lv87 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_15_self_attn_c_attn_q_weight5, model_layers_15_self_attn_c_attn_q_scale5, rms_norm322, model_layers_15_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape636 = R.call_tir(cls.reshape4, (lv87,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape637 = R.call_tir(cls.reshape5, (reshape636,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv802 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape637), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape638 = R.call_tir(cls.reshape6, (lv802,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape639 = R.call_tir(cls.reshape7, (reshape638,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv263 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_15_self_attn_o_proj_q_weight5, model_layers_15_self_attn_o_proj_q_scale5, reshape639), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv348 = R.call_tir(cls.fuse_add_norm_prefill, (lv263, lv347, model_layers_15_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv349: R.Tensor((1, seq_len, 2048), dtype="float16") = lv348[1]
            rms_norm323: R.Tensor((1, seq_len, 2048), dtype="float16") = lv348[0]
            lv264 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_15_mlp_gate_up_proj_q_weight5, model_layers_15_mlp_gate_up_proj_q_scale5, rms_norm323), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv175 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv264,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv265 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_15_mlp_down_proj_q_weight5, model_layers_15_mlp_down_proj_q_scale5, lv175), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv350 = R.call_tir(cls.fuse_add_norm_prefill, (lv265, lv349, model_layers_16_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv351: R.Tensor((1, seq_len, 2048), dtype="float16") = lv350[1]
            rms_norm324: R.Tensor((1, seq_len, 2048), dtype="float16") = lv350[0]
            lv88 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_16_self_attn_c_attn_q_weight5, model_layers_16_self_attn_c_attn_q_scale5, rms_norm324, model_layers_16_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape640 = R.call_tir(cls.reshape4, (lv88,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape641 = R.call_tir(cls.reshape5, (reshape640,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv807 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape641), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape642 = R.call_tir(cls.reshape6, (lv807,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape643 = R.call_tir(cls.reshape7, (reshape642,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv266 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_16_self_attn_o_proj_q_weight5, model_layers_16_self_attn_o_proj_q_scale5, reshape643), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv352 = R.call_tir(cls.fuse_add_norm_prefill, (lv266, lv351, model_layers_16_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv353: R.Tensor((1, seq_len, 2048), dtype="float16") = lv352[1]
            rms_norm325: R.Tensor((1, seq_len, 2048), dtype="float16") = lv352[0]
            lv267 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_16_mlp_gate_up_proj_q_weight5, model_layers_16_mlp_gate_up_proj_q_scale5, rms_norm325), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv177 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv267,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv268 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_16_mlp_down_proj_q_weight5, model_layers_16_mlp_down_proj_q_scale5, lv177), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv354 = R.call_tir(cls.fuse_add_norm_prefill, (lv268, lv353, model_layers_17_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv355: R.Tensor((1, seq_len, 2048), dtype="float16") = lv354[1]
            rms_norm326: R.Tensor((1, seq_len, 2048), dtype="float16") = lv354[0]
            lv89 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_17_self_attn_c_attn_q_weight5, model_layers_17_self_attn_c_attn_q_scale5, rms_norm326, model_layers_17_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape644 = R.call_tir(cls.reshape4, (lv89,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape645 = R.call_tir(cls.reshape5, (reshape644,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv812 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape645), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape646 = R.call_tir(cls.reshape6, (lv812,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape647 = R.call_tir(cls.reshape7, (reshape646,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv269 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_17_self_attn_o_proj_q_weight5, model_layers_17_self_attn_o_proj_q_scale5, reshape647), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv356 = R.call_tir(cls.fuse_add_norm_prefill, (lv269, lv355, model_layers_17_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv357: R.Tensor((1, seq_len, 2048), dtype="float16") = lv356[1]
            rms_norm327: R.Tensor((1, seq_len, 2048), dtype="float16") = lv356[0]
            lv270 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_17_mlp_gate_up_proj_q_weight5, model_layers_17_mlp_gate_up_proj_q_scale5, rms_norm327), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv179 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv270,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv271 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_17_mlp_down_proj_q_weight5, model_layers_17_mlp_down_proj_q_scale5, lv179), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv358 = R.call_tir(cls.fuse_add_norm_prefill, (lv271, lv357, model_layers_18_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv359: R.Tensor((1, seq_len, 2048), dtype="float16") = lv358[1]
            rms_norm328: R.Tensor((1, seq_len, 2048), dtype="float16") = lv358[0]
            lv90 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_18_self_attn_c_attn_q_weight5, model_layers_18_self_attn_c_attn_q_scale5, rms_norm328, model_layers_18_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape648 = R.call_tir(cls.reshape4, (lv90,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape649 = R.call_tir(cls.reshape5, (reshape648,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv817 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape649), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape650 = R.call_tir(cls.reshape6, (lv817,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape651 = R.call_tir(cls.reshape7, (reshape650,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv272 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_18_self_attn_o_proj_q_weight5, model_layers_18_self_attn_o_proj_q_scale5, reshape651), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv360 = R.call_tir(cls.fuse_add_norm_prefill, (lv272, lv359, model_layers_18_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv361: R.Tensor((1, seq_len, 2048), dtype="float16") = lv360[1]
            rms_norm329: R.Tensor((1, seq_len, 2048), dtype="float16") = lv360[0]
            lv273 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_18_mlp_gate_up_proj_q_weight5, model_layers_18_mlp_gate_up_proj_q_scale5, rms_norm329), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv181 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv273,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv274 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_18_mlp_down_proj_q_weight5, model_layers_18_mlp_down_proj_q_scale5, lv181), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv362 = R.call_tir(cls.fuse_add_norm_prefill, (lv274, lv361, model_layers_19_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv363: R.Tensor((1, seq_len, 2048), dtype="float16") = lv362[1]
            rms_norm330: R.Tensor((1, seq_len, 2048), dtype="float16") = lv362[0]
            lv91 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_19_self_attn_c_attn_q_weight5, model_layers_19_self_attn_c_attn_q_scale5, rms_norm330, model_layers_19_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape652 = R.call_tir(cls.reshape4, (lv91,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape653 = R.call_tir(cls.reshape5, (reshape652,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv822 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape653), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape654 = R.call_tir(cls.reshape6, (lv822,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape655 = R.call_tir(cls.reshape7, (reshape654,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv275 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_19_self_attn_o_proj_q_weight5, model_layers_19_self_attn_o_proj_q_scale5, reshape655), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv364 = R.call_tir(cls.fuse_add_norm_prefill, (lv275, lv363, model_layers_19_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv365: R.Tensor((1, seq_len, 2048), dtype="float16") = lv364[1]
            rms_norm331: R.Tensor((1, seq_len, 2048), dtype="float16") = lv364[0]
            lv276 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_19_mlp_gate_up_proj_q_weight5, model_layers_19_mlp_gate_up_proj_q_scale5, rms_norm331), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv183 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv276,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv277 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_19_mlp_down_proj_q_weight5, model_layers_19_mlp_down_proj_q_scale5, lv183), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv366 = R.call_tir(cls.fuse_add_norm_prefill, (lv277, lv365, model_layers_20_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv367: R.Tensor((1, seq_len, 2048), dtype="float16") = lv366[1]
            rms_norm332: R.Tensor((1, seq_len, 2048), dtype="float16") = lv366[0]
            lv92 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_20_self_attn_c_attn_q_weight5, model_layers_20_self_attn_c_attn_q_scale5, rms_norm332, model_layers_20_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape656 = R.call_tir(cls.reshape4, (lv92,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape657 = R.call_tir(cls.reshape5, (reshape656,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv827 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape657), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape658 = R.call_tir(cls.reshape6, (lv827,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape659 = R.call_tir(cls.reshape7, (reshape658,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv278 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_20_self_attn_o_proj_q_weight5, model_layers_20_self_attn_o_proj_q_scale5, reshape659), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv368 = R.call_tir(cls.fuse_add_norm_prefill, (lv278, lv367, model_layers_20_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv369: R.Tensor((1, seq_len, 2048), dtype="float16") = lv368[1]
            rms_norm333: R.Tensor((1, seq_len, 2048), dtype="float16") = lv368[0]
            lv279 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_20_mlp_gate_up_proj_q_weight5, model_layers_20_mlp_gate_up_proj_q_scale5, rms_norm333), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv185 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv279,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv280 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_20_mlp_down_proj_q_weight5, model_layers_20_mlp_down_proj_q_scale5, lv185), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv370 = R.call_tir(cls.fuse_add_norm_prefill, (lv280, lv369, model_layers_21_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv371: R.Tensor((1, seq_len, 2048), dtype="float16") = lv370[1]
            rms_norm334: R.Tensor((1, seq_len, 2048), dtype="float16") = lv370[0]
            lv93 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_21_self_attn_c_attn_q_weight5, model_layers_21_self_attn_c_attn_q_scale5, rms_norm334, model_layers_21_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape660 = R.call_tir(cls.reshape4, (lv93,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape661 = R.call_tir(cls.reshape5, (reshape660,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv832 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape661), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape662 = R.call_tir(cls.reshape6, (lv832,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape663 = R.call_tir(cls.reshape7, (reshape662,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv281 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_21_self_attn_o_proj_q_weight5, model_layers_21_self_attn_o_proj_q_scale5, reshape663), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv372 = R.call_tir(cls.fuse_add_norm_prefill, (lv281, lv371, model_layers_21_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv373: R.Tensor((1, seq_len, 2048), dtype="float16") = lv372[1]
            rms_norm335: R.Tensor((1, seq_len, 2048), dtype="float16") = lv372[0]
            lv282 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_21_mlp_gate_up_proj_q_weight5, model_layers_21_mlp_gate_up_proj_q_scale5, rms_norm335), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv187 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv282,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv283 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_21_mlp_down_proj_q_weight5, model_layers_21_mlp_down_proj_q_scale5, lv187), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv374 = R.call_tir(cls.fuse_add_norm_prefill, (lv283, lv373, model_layers_22_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv375: R.Tensor((1, seq_len, 2048), dtype="float16") = lv374[1]
            rms_norm336: R.Tensor((1, seq_len, 2048), dtype="float16") = lv374[0]
            lv94 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_22_self_attn_c_attn_q_weight5, model_layers_22_self_attn_c_attn_q_scale5, rms_norm336, model_layers_22_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape664 = R.call_tir(cls.reshape4, (lv94,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape665 = R.call_tir(cls.reshape5, (reshape664,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv837 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape665), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape666 = R.call_tir(cls.reshape6, (lv837,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape667 = R.call_tir(cls.reshape7, (reshape666,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv284 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_22_self_attn_o_proj_q_weight5, model_layers_22_self_attn_o_proj_q_scale5, reshape667), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv376 = R.call_tir(cls.fuse_add_norm_prefill, (lv284, lv375, model_layers_22_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv377: R.Tensor((1, seq_len, 2048), dtype="float16") = lv376[1]
            rms_norm337: R.Tensor((1, seq_len, 2048), dtype="float16") = lv376[0]
            lv285 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_22_mlp_gate_up_proj_q_weight5, model_layers_22_mlp_gate_up_proj_q_scale5, rms_norm337), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv189 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv285,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv286 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_22_mlp_down_proj_q_weight5, model_layers_22_mlp_down_proj_q_scale5, lv189), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv378 = R.call_tir(cls.fuse_add_norm_prefill, (lv286, lv377, model_layers_23_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv379: R.Tensor((1, seq_len, 2048), dtype="float16") = lv378[1]
            rms_norm338: R.Tensor((1, seq_len, 2048), dtype="float16") = lv378[0]
            lv95 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_23_self_attn_c_attn_q_weight5, model_layers_23_self_attn_c_attn_q_scale5, rms_norm338, model_layers_23_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape668 = R.call_tir(cls.reshape4, (lv95,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape669 = R.call_tir(cls.reshape5, (reshape668,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv842 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape669), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape670 = R.call_tir(cls.reshape6, (lv842,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape671 = R.call_tir(cls.reshape7, (reshape670,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv287 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_23_self_attn_o_proj_q_weight5, model_layers_23_self_attn_o_proj_q_scale5, reshape671), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv380 = R.call_tir(cls.fuse_add_norm_prefill, (lv287, lv379, model_layers_23_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv381: R.Tensor((1, seq_len, 2048), dtype="float16") = lv380[1]
            rms_norm339: R.Tensor((1, seq_len, 2048), dtype="float16") = lv380[0]
            lv288_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_23_mlp_gate_up_proj_q_weight5, model_layers_23_mlp_gate_up_proj_q_scale5, rms_norm339), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv191 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv288_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv289_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_23_mlp_down_proj_q_weight5, model_layers_23_mlp_down_proj_q_scale5, lv191), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv382 = R.call_tir(cls.fuse_add_norm_prefill, (lv289_1, lv381, model_layers_24_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv383: R.Tensor((1, seq_len, 2048), dtype="float16") = lv382[1]
            rms_norm340: R.Tensor((1, seq_len, 2048), dtype="float16") = lv382[0]
            lv96 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_24_self_attn_c_attn_q_weight5, model_layers_24_self_attn_c_attn_q_scale5, rms_norm340, model_layers_24_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape672 = R.call_tir(cls.reshape4, (lv96,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape673 = R.call_tir(cls.reshape5, (reshape672,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv847 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(24), R.prim_value(T.float32(1.0)), reshape673), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape674 = R.call_tir(cls.reshape6, (lv847,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape675 = R.call_tir(cls.reshape7, (reshape674,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv290_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_24_self_attn_o_proj_q_weight5, model_layers_24_self_attn_o_proj_q_scale5, reshape675), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv384 = R.call_tir(cls.fuse_add_norm_prefill, (lv290_1, lv383, model_layers_24_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv385: R.Tensor((1, seq_len, 2048), dtype="float16") = lv384[1]
            rms_norm341: R.Tensor((1, seq_len, 2048), dtype="float16") = lv384[0]
            lv291_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_24_mlp_gate_up_proj_q_weight5, model_layers_24_mlp_gate_up_proj_q_scale5, rms_norm341), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv193 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv291_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv292_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_24_mlp_down_proj_q_weight5, model_layers_24_mlp_down_proj_q_scale5, lv193), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv386 = R.call_tir(cls.fuse_add_norm_prefill, (lv292_1, lv385, model_layers_25_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv387: R.Tensor((1, seq_len, 2048), dtype="float16") = lv386[1]
            rms_norm342: R.Tensor((1, seq_len, 2048), dtype="float16") = lv386[0]
            lv97 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_25_self_attn_c_attn_q_weight5, model_layers_25_self_attn_c_attn_q_scale5, rms_norm342, model_layers_25_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape676 = R.call_tir(cls.reshape4, (lv97,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape677 = R.call_tir(cls.reshape5, (reshape676,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv852 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(25), R.prim_value(T.float32(1.0)), reshape677), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape678 = R.call_tir(cls.reshape6, (lv852,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape679 = R.call_tir(cls.reshape7, (reshape678,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv293_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_25_self_attn_o_proj_q_weight5, model_layers_25_self_attn_o_proj_q_scale5, reshape679), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv388 = R.call_tir(cls.fuse_add_norm_prefill, (lv293_1, lv387, model_layers_25_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv389: R.Tensor((1, seq_len, 2048), dtype="float16") = lv388[1]
            rms_norm343: R.Tensor((1, seq_len, 2048), dtype="float16") = lv388[0]
            lv294_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_25_mlp_gate_up_proj_q_weight5, model_layers_25_mlp_gate_up_proj_q_scale5, rms_norm343), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv195 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv294_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv295_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_25_mlp_down_proj_q_weight5, model_layers_25_mlp_down_proj_q_scale5, lv195), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv390 = R.call_tir(cls.fuse_add_norm_prefill, (lv295_1, lv389, model_layers_26_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv391: R.Tensor((1, seq_len, 2048), dtype="float16") = lv390[1]
            rms_norm344: R.Tensor((1, seq_len, 2048), dtype="float16") = lv390[0]
            lv98 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_26_self_attn_c_attn_q_weight5, model_layers_26_self_attn_c_attn_q_scale5, rms_norm344, model_layers_26_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape680 = R.call_tir(cls.reshape4, (lv98,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape681 = R.call_tir(cls.reshape5, (reshape680,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv857 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(26), R.prim_value(T.float32(1.0)), reshape681), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape682 = R.call_tir(cls.reshape6, (lv857,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape683 = R.call_tir(cls.reshape7, (reshape682,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv296_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_26_self_attn_o_proj_q_weight5, model_layers_26_self_attn_o_proj_q_scale5, reshape683), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv392 = R.call_tir(cls.fuse_add_norm_prefill, (lv296_1, lv391, model_layers_26_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv393: R.Tensor((1, seq_len, 2048), dtype="float16") = lv392[1]
            rms_norm345: R.Tensor((1, seq_len, 2048), dtype="float16") = lv392[0]
            lv297_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_26_mlp_gate_up_proj_q_weight5, model_layers_26_mlp_gate_up_proj_q_scale5, rms_norm345), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv197 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv297_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv298_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_26_mlp_down_proj_q_weight5, model_layers_26_mlp_down_proj_q_scale5, lv197), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv394 = R.call_tir(cls.fuse_add_norm_prefill, (lv298_1, lv393, model_layers_27_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv395: R.Tensor((1, seq_len, 2048), dtype="float16") = lv394[1]
            rms_norm346: R.Tensor((1, seq_len, 2048), dtype="float16") = lv394[0]
            lv99 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_27_self_attn_c_attn_q_weight5, model_layers_27_self_attn_c_attn_q_scale5, rms_norm346, model_layers_27_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape684 = R.call_tir(cls.reshape4, (lv99,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape685 = R.call_tir(cls.reshape5, (reshape684,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv862 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(27), R.prim_value(T.float32(1.0)), reshape685), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape686 = R.call_tir(cls.reshape6, (lv862,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape687 = R.call_tir(cls.reshape7, (reshape686,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv299_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_27_self_attn_o_proj_q_weight5, model_layers_27_self_attn_o_proj_q_scale5, reshape687), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv396 = R.call_tir(cls.fuse_add_norm_prefill, (lv299_1, lv395, model_layers_27_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv397: R.Tensor((1, seq_len, 2048), dtype="float16") = lv396[1]
            rms_norm347: R.Tensor((1, seq_len, 2048), dtype="float16") = lv396[0]
            lv300_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_27_mlp_gate_up_proj_q_weight5, model_layers_27_mlp_gate_up_proj_q_scale5, rms_norm347), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv199 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv300_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv301_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_27_mlp_down_proj_q_weight5, model_layers_27_mlp_down_proj_q_scale5, lv199), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv398 = R.call_tir(cls.fuse_add_norm_prefill, (lv301_1, lv397, model_layers_28_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv399: R.Tensor((1, seq_len, 2048), dtype="float16") = lv398[1]
            rms_norm348: R.Tensor((1, seq_len, 2048), dtype="float16") = lv398[0]
            lv100 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_28_self_attn_c_attn_q_weight5, model_layers_28_self_attn_c_attn_q_scale5, rms_norm348, model_layers_28_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape688 = R.call_tir(cls.reshape4, (lv100,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape689 = R.call_tir(cls.reshape5, (reshape688,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv867 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(28), R.prim_value(T.float32(1.0)), reshape689), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape690 = R.call_tir(cls.reshape6, (lv867,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape691 = R.call_tir(cls.reshape7, (reshape690,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv302_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_28_self_attn_o_proj_q_weight5, model_layers_28_self_attn_o_proj_q_scale5, reshape691), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv400 = R.call_tir(cls.fuse_add_norm_prefill, (lv302_1, lv399, model_layers_28_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv401: R.Tensor((1, seq_len, 2048), dtype="float16") = lv400[1]
            rms_norm349: R.Tensor((1, seq_len, 2048), dtype="float16") = lv400[0]
            lv303_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_28_mlp_gate_up_proj_q_weight5, model_layers_28_mlp_gate_up_proj_q_scale5, rms_norm349), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv201 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv303_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv304_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_28_mlp_down_proj_q_weight5, model_layers_28_mlp_down_proj_q_scale5, lv201), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv402 = R.call_tir(cls.fuse_add_norm_prefill, (lv304_1, lv401, model_layers_29_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv403: R.Tensor((1, seq_len, 2048), dtype="float16") = lv402[1]
            rms_norm350: R.Tensor((1, seq_len, 2048), dtype="float16") = lv402[0]
            lv101 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_29_self_attn_c_attn_q_weight5, model_layers_29_self_attn_c_attn_q_scale5, rms_norm350, model_layers_29_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape692 = R.call_tir(cls.reshape4, (lv101,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape693 = R.call_tir(cls.reshape5, (reshape692,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv872 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(29), R.prim_value(T.float32(1.0)), reshape693), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape694 = R.call_tir(cls.reshape6, (lv872,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape695 = R.call_tir(cls.reshape7, (reshape694,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv305_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_29_self_attn_o_proj_q_weight5, model_layers_29_self_attn_o_proj_q_scale5, reshape695), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv404 = R.call_tir(cls.fuse_add_norm_prefill, (lv305_1, lv403, model_layers_29_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv405: R.Tensor((1, seq_len, 2048), dtype="float16") = lv404[1]
            rms_norm351: R.Tensor((1, seq_len, 2048), dtype="float16") = lv404[0]
            lv306_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_29_mlp_gate_up_proj_q_weight5, model_layers_29_mlp_gate_up_proj_q_scale5, rms_norm351), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv203 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv306_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv307_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_29_mlp_down_proj_q_weight5, model_layers_29_mlp_down_proj_q_scale5, lv203), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv406 = R.call_tir(cls.fuse_add_norm_prefill, (lv307_1, lv405, model_layers_30_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv407: R.Tensor((1, seq_len, 2048), dtype="float16") = lv406[1]
            rms_norm352: R.Tensor((1, seq_len, 2048), dtype="float16") = lv406[0]
            lv102 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_30_self_attn_c_attn_q_weight5, model_layers_30_self_attn_c_attn_q_scale5, rms_norm352, model_layers_30_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape696 = R.call_tir(cls.reshape4, (lv102,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape697 = R.call_tir(cls.reshape5, (reshape696,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv877 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(30), R.prim_value(T.float32(1.0)), reshape697), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape698 = R.call_tir(cls.reshape6, (lv877,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape699 = R.call_tir(cls.reshape7, (reshape698,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv308_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_30_self_attn_o_proj_q_weight5, model_layers_30_self_attn_o_proj_q_scale5, reshape699), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv408 = R.call_tir(cls.fuse_add_norm_prefill, (lv308_1, lv407, model_layers_30_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv409: R.Tensor((1, seq_len, 2048), dtype="float16") = lv408[1]
            rms_norm353: R.Tensor((1, seq_len, 2048), dtype="float16") = lv408[0]
            lv309_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_30_mlp_gate_up_proj_q_weight5, model_layers_30_mlp_gate_up_proj_q_scale5, rms_norm353), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv205 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv309_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv310_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_30_mlp_down_proj_q_weight5, model_layers_30_mlp_down_proj_q_scale5, lv205), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv410 = R.call_tir(cls.fuse_add_norm_prefill, (lv310_1, lv409, model_layers_31_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv411: R.Tensor((1, seq_len, 2048), dtype="float16") = lv410[1]
            rms_norm354: R.Tensor((1, seq_len, 2048), dtype="float16") = lv410[0]
            lv103 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_31_self_attn_c_attn_q_weight5, model_layers_31_self_attn_c_attn_q_scale5, rms_norm354, model_layers_31_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape700 = R.call_tir(cls.reshape4, (lv103,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape701 = R.call_tir(cls.reshape5, (reshape700,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv882 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(31), R.prim_value(T.float32(1.0)), reshape701), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape702 = R.call_tir(cls.reshape6, (lv882,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape703 = R.call_tir(cls.reshape7, (reshape702,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv311_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_31_self_attn_o_proj_q_weight5, model_layers_31_self_attn_o_proj_q_scale5, reshape703), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv412 = R.call_tir(cls.fuse_add_norm_prefill, (lv311_1, lv411, model_layers_31_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv413: R.Tensor((1, seq_len, 2048), dtype="float16") = lv412[1]
            rms_norm355: R.Tensor((1, seq_len, 2048), dtype="float16") = lv412[0]
            lv312_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_31_mlp_gate_up_proj_q_weight5, model_layers_31_mlp_gate_up_proj_q_scale5, rms_norm355), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv207 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv312_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv313_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_31_mlp_down_proj_q_weight5, model_layers_31_mlp_down_proj_q_scale5, lv207), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv414 = R.call_tir(cls.fuse_add_norm_prefill, (lv313_1, lv413, model_layers_32_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv415: R.Tensor((1, seq_len, 2048), dtype="float16") = lv414[1]
            rms_norm356: R.Tensor((1, seq_len, 2048), dtype="float16") = lv414[0]
            lv104 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_32_self_attn_c_attn_q_weight5, model_layers_32_self_attn_c_attn_q_scale5, rms_norm356, model_layers_32_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape704 = R.call_tir(cls.reshape4, (lv104,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape705 = R.call_tir(cls.reshape5, (reshape704,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv887 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(32), R.prim_value(T.float32(1.0)), reshape705), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape706 = R.call_tir(cls.reshape6, (lv887,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape707 = R.call_tir(cls.reshape7, (reshape706,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv314_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_32_self_attn_o_proj_q_weight5, model_layers_32_self_attn_o_proj_q_scale5, reshape707), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv416 = R.call_tir(cls.fuse_add_norm_prefill, (lv314_1, lv415, model_layers_32_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv417: R.Tensor((1, seq_len, 2048), dtype="float16") = lv416[1]
            rms_norm357: R.Tensor((1, seq_len, 2048), dtype="float16") = lv416[0]
            lv315_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_32_mlp_gate_up_proj_q_weight5, model_layers_32_mlp_gate_up_proj_q_scale5, rms_norm357), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv209 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv315_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv316_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_32_mlp_down_proj_q_weight5, model_layers_32_mlp_down_proj_q_scale5, lv209), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv418 = R.call_tir(cls.fuse_add_norm_prefill, (lv316_1, lv417, model_layers_33_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv419: R.Tensor((1, seq_len, 2048), dtype="float16") = lv418[1]
            rms_norm358: R.Tensor((1, seq_len, 2048), dtype="float16") = lv418[0]
            lv105 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_33_self_attn_c_attn_q_weight5, model_layers_33_self_attn_c_attn_q_scale5, rms_norm358, model_layers_33_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape708 = R.call_tir(cls.reshape4, (lv105,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape709 = R.call_tir(cls.reshape5, (reshape708,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv892 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(33), R.prim_value(T.float32(1.0)), reshape709), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape710 = R.call_tir(cls.reshape6, (lv892,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape711 = R.call_tir(cls.reshape7, (reshape710,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv317_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_33_self_attn_o_proj_q_weight5, model_layers_33_self_attn_o_proj_q_scale5, reshape711), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv420 = R.call_tir(cls.fuse_add_norm_prefill, (lv317_1, lv419, model_layers_33_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv421: R.Tensor((1, seq_len, 2048), dtype="float16") = lv420[1]
            rms_norm359: R.Tensor((1, seq_len, 2048), dtype="float16") = lv420[0]
            lv318_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_33_mlp_gate_up_proj_q_weight5, model_layers_33_mlp_gate_up_proj_q_scale5, rms_norm359), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv211 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv318_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv319_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_33_mlp_down_proj_q_weight5, model_layers_33_mlp_down_proj_q_scale5, lv211), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv422 = R.call_tir(cls.fuse_add_norm_prefill, (lv319_1, lv421, model_layers_34_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv423: R.Tensor((1, seq_len, 2048), dtype="float16") = lv422[1]
            rms_norm360: R.Tensor((1, seq_len, 2048), dtype="float16") = lv422[0]
            lv106 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_34_self_attn_c_attn_q_weight5, model_layers_34_self_attn_c_attn_q_scale5, rms_norm360, model_layers_34_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape712 = R.call_tir(cls.reshape4, (lv106,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape713 = R.call_tir(cls.reshape5, (reshape712,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv897 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(34), R.prim_value(T.float32(1.0)), reshape713), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape714 = R.call_tir(cls.reshape6, (lv897,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape715 = R.call_tir(cls.reshape7, (reshape714,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv320_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_34_self_attn_o_proj_q_weight5, model_layers_34_self_attn_o_proj_q_scale5, reshape715), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv424 = R.call_tir(cls.fuse_add_norm_prefill, (lv320_1, lv423, model_layers_34_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv425: R.Tensor((1, seq_len, 2048), dtype="float16") = lv424[1]
            rms_norm361: R.Tensor((1, seq_len, 2048), dtype="float16") = lv424[0]
            lv321_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_34_mlp_gate_up_proj_q_weight5, model_layers_34_mlp_gate_up_proj_q_scale5, rms_norm361), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv213 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv321_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv322_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_34_mlp_down_proj_q_weight5, model_layers_34_mlp_down_proj_q_scale5, lv213), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv426 = R.call_tir(cls.fuse_add_norm_prefill, (lv322_1, lv425, model_layers_35_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv427: R.Tensor((1, seq_len, 2048), dtype="float16") = lv426[1]
            rms_norm362: R.Tensor((1, seq_len, 2048), dtype="float16") = lv426[0]
            lv107 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul5_add1, (model_layers_35_self_attn_c_attn_q_weight5, model_layers_35_self_attn_c_attn_q_scale5, rms_norm362, model_layers_35_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 2560), dtype="float16"))
            reshape716 = R.call_tir(cls.reshape4, (lv107,), out_sinfo=R.Tensor((1, seq_len, 20, 128), dtype="float16"))
            reshape717 = R.call_tir(cls.reshape5, (reshape716,), out_sinfo=R.Tensor((seq_len, 20, 128), dtype="float16"))
            lv902 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(35), R.prim_value(T.float32(1.0)), reshape717), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape718 = R.call_tir(cls.reshape6, (lv902,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape719 = R.call_tir(cls.reshape7, (reshape718,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv323_1 = R.call_tir(cls.fused_dequantize2_NT_matmul6, (model_layers_35_self_attn_o_proj_q_weight5, model_layers_35_self_attn_o_proj_q_scale5, reshape719), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv428 = R.call_tir(cls.fuse_add_norm_prefill, (lv323_1, lv427, model_layers_35_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv429: R.Tensor((1, seq_len, 2048), dtype="float16") = lv428[1]
            rms_norm363: R.Tensor((1, seq_len, 2048), dtype="float16") = lv428[0]
            lv324_1 = R.call_tir(cls.fused_dequantize3_NT_matmul7, (model_layers_35_mlp_gate_up_proj_q_weight5, model_layers_35_mlp_gate_up_proj_q_scale5, rms_norm363), out_sinfo=R.Tensor((1, seq_len, 22016), dtype="float16"))
            lv215 = R.call_tir(cls.fused_split1_silu1_multiply1, (lv324_1,), out_sinfo=R.Tensor((1, seq_len, 11008), dtype="float16"))
            lv325_1 = R.call_tir(cls.fused_dequantize4_NT_matmul8, (model_layers_35_mlp_down_proj_q_weight5, model_layers_35_mlp_down_proj_q_scale5, lv215), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv430 = R.call_tir(cls.fuse_add_norm_prefill, (lv325_1, lv429, model_norm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            rms_norm364: R.Tensor((1, seq_len, 2048), dtype="float16") = lv430[0]
            lv326_1 = R.call_tir(cls.fused_dequantize_NT_matmul9, (model_embed_tokens_q_weight5, model_embed_tokens_q_scale5, rms_norm364), out_sinfo=R.Tensor((1, seq_len, 151936), dtype="float32"))
            gv5: R.Tuple(R.Tensor((1, seq_len, 151936), dtype="float32"), R.Object) = lv326_1, paged_kv_cache
            R.output(gv5)
        return gv5

    @R.function
    def create_flashinfer_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object:
        # Create the paged KV-cache object whose prefill/decode attention is
        # delegated to FlashInfer extern functions.
        #
        # Parameters: five single-element R.Shape values. Each one binds the
        # symbolic int64 of the same name (without the trailing underscore).
        # They are forwarded, in this order, as the first shape argument of
        # "vm.builtin.paged_attention_kv_cache_create".
        # Returns: the opaque R.Object produced by that packed function.
        #
        # NOTE(review): this is printed TVMScript and appears compiler-generated;
        # prefer changing the generator over hand-editing this IR.

        # Symbolic sizes bound from the R.Shape parameters above.
        max_batch_size = T.int64()
        max_total_seq_len = T.int64()
        prefill_chunk_size = T.int64()
        page_size = T.int64()
        support_sliding_window = T.int64()
        # The non-negative / upper-bound attributes below name vars (vocab_size,
        # batch_size, seq_len, total_seq_len) that do not occur in this function.
        # They appear to be attached uniformly to every Relax function of the
        # module. TODO confirm they are inert here.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        # Positional arguments after the size shape. These readings are inferred
        # from elsewhere in this module; NOTE(review) confirm them against the
        # runtime constructor's signature:
        #   R.shape([0, 36])        presumably the layer range; the model code
        #                           addresses the cache with layer ids up to 35.
        #   16, 2, 128              presumably query heads, KV heads and head dim:
        #                           fused QKV is reshaped to 20 heads x 128 and the
        #                           attention output has 16 heads (16 + 2 + 2 = 20).
        #   1, 1                    meaning not visible from here.
        #   1000000.0               presumably the RoPE theta used by cls.fused_rope.
        #   R.const(0.0, "float16") presumably conveys the cache element dtype
        #                           (float16, matching the activations).
        # Kernel arguments, in order:
        #   - cls.tir_kv_cache_transpose_append;
        #   - FlashInfer paged prefill and decode;
        #   - cls sliding-window prefill and decode;
        #   - FlashInfer ragged prefill plus its begin/end-forward hooks;
        #   - paged prefill/decode begin/end-forward hooks;
        #   - flashinfer.merge_state_in_place;
        #   - cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv,
        #     cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv.
        # The two trailing R.prim_value(0) have no meaning visible here.
        paged_kv_cache1: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.shape([0, 36]), R.prim_value(16), R.prim_value(2), R.prim_value(128), R.prim_value(1), R.prim_value(1), R.prim_value(T.float32(1000000.0)), R.const(0.0, "float16"), cls.tir_kv_cache_transpose_append, R.ExternFunc("flashinfer.attention_kernel_prefill_with_paged_kv_cache"), R.ExternFunc("flashinfer.attention_kernel_decode_with_paged_kv_cache"), cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_paged_kv_cache_begin_forward"), R.ExternFunc("flashinfer.attention_kernel_prefill_with_paged_kv_cache_end_forward"), R.ExternFunc("flashinfer.attention_kernel_decode_with_paged_kv_cache_begin_forward"), R.ExternFunc("flashinfer.attention_kernel_decode_with_paged_kv_cache_end_forward"), R.ExternFunc("flashinfer.merge_state_in_place"), cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv, R.prim_value(0), R.prim_value(0), sinfo_args=(R.Object,))
        return paged_kv_cache1

    @R.function
    def create_tir_paged_kv_cache(max_batch_size_: R.Shape(["max_batch_size"]), max_total_seq_len_: R.Shape(["max_total_seq_len"]), prefill_chunk_size_: R.Shape(["prefill_chunk_size"]), page_size_: R.Shape(["page_size"]), support_sliding_window_: R.Shape(["support_sliding_window"])) -> R.Object:
        # Create the paged KV-cache object using only TIR kernels defined in this
        # module; no FlashInfer extern functions are referenced.
        #
        # Parameters: five single-element R.Shape values. Each one binds the
        # symbolic int64 of the same name (without the trailing underscore).
        # They are forwarded, in this order, as the first shape argument of
        # "vm.builtin.paged_attention_kv_cache_create_reduced".
        # Returns: the opaque R.Object produced by that packed function.
        #
        # NOTE(review): this is printed TVMScript and appears compiler-generated;
        # prefer changing the generator over hand-editing this IR.

        # Symbolic sizes bound from the R.Shape parameters above.
        max_batch_size = T.int64()
        max_total_seq_len = T.int64()
        prefill_chunk_size = T.int64()
        page_size = T.int64()
        support_sliding_window = T.int64()
        # The non-negative / upper-bound attributes below name vars (vocab_size,
        # batch_size, seq_len, total_seq_len) that do not occur in this function.
        # They appear to be attached uniformly to every Relax function of the
        # module. TODO confirm they are inert here.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        # The leading configuration arguments (R.shape([0, 36]), 16, 2, 128, 1, 1,
        # 1000000.0, float16 zero constant) match those passed to the FlashInfer
        # variant, create_flashinfer_paged_kv_cache. Presumably they are the layer
        # range, query/KV head counts, head dim and RoPE theta; NOTE(review)
        # confirm against the runtime signature.
        # This "_reduced" constructor takes a shorter kernel list, all TIR
        # functions from this module:
        #   - transpose-append;
        #   - paged prefill and decode;
        #   - sliding-window prefill and decode;
        #   - ragged prefill;
        #   - merge_state_inplace;
        #   - fused_rope, copy_single_page, debug get_kv, compact_kv_copy and
        #     the two tree-attention kernels.
        # It passes no begin/end-forward hooks. The two trailing R.prim_value(0)
        # have no meaning visible here.
        paged_kv_cache: R.Object = R.call_pure_packed("vm.builtin.paged_attention_kv_cache_create_reduced", R.shape([max_batch_size, max_total_seq_len, prefill_chunk_size, page_size, support_sliding_window]), R.shape([0, 36]), R.prim_value(16), R.prim_value(2), R.prim_value(128), R.prim_value(1), R.prim_value(1), R.prim_value(T.float32(1000000.0)), R.const(0.0, "float16"), cls.tir_kv_cache_transpose_append, cls.batch_prefill_paged_kv, cls.batch_decode_paged_kv, cls.batch_prefill_paged_kv_sliding_window, cls.batch_decode_paged_kv_sliding_window, cls.batch_prefill_ragged_kv, cls.merge_state_inplace, cls.fused_rope, cls.copy_single_page, cls.tir_kv_cache_debug_get_kv, cls.compact_kv_copy, cls.batch_tree_attn, cls.tree_attn_paged_kv, R.prim_value(0), R.prim_value(0), sinfo_args=(R.Object,))
        return paged_kv_cache

    @R.function
    def decode(input_embed: R.Tensor((1, 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), 
dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), 
dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 
256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), 
R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), 
dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), 
dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2560, 256), dtype="uint32"), R.Tensor((2560, 64), dtype="float16"), R.Tensor((2560,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((22016, 256), dtype="uint32"), R.Tensor((22016, 64), dtype="float16"), R.Tensor((2048, 1376), dtype="uint32"), R.Tensor((2048, 344), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"))) -> R.Tuple(R.Tensor((1, 1, 151936), dtype="float32"), R.Object):
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 2048, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_embed_tokens_q_weight2: R.Tensor((151936, 256), dtype="uint32") = packed_params[0]
            model_embed_tokens_q_scale2: R.Tensor((151936, 64), dtype="float16") = packed_params[1]
            model_layers_0_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[10]
            model_layers_0_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[11]
            model_layers_0_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[12]
            model_layers_1_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[13]
            model_layers_1_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[14]
            model_layers_1_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[15]
            model_layers_1_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[16]
            model_layers_1_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[17]
            model_layers_1_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[18]
            model_layers_1_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[19]
            model_layers_1_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[20]
            model_layers_1_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[21]
            model_layers_1_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[22]
            model_layers_1_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[23]
            model_layers_2_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[24]
            model_layers_2_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[25]
            model_layers_2_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[26]
            model_layers_2_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[27]
            model_layers_2_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[28]
            model_layers_2_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[29]
            model_layers_2_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[30]
            model_layers_2_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[31]
            model_layers_2_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[32]
            model_layers_2_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[33]
            model_layers_2_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_3_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[35]
            model_layers_3_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[36]
            model_layers_3_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[37]
            model_layers_3_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[38]
            model_layers_3_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[39]
            model_layers_3_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[40]
            model_layers_3_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[41]
            model_layers_3_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[42]
            model_layers_3_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[43]
            model_layers_3_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[44]
            model_layers_3_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[45]
            model_layers_4_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[46]
            model_layers_4_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[47]
            model_layers_4_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[48]
            model_layers_4_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[49]
            model_layers_4_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[50]
            model_layers_4_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[51]
            model_layers_4_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[52]
            model_layers_4_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[53]
            model_layers_4_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[54]
            model_layers_4_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[55]
            model_layers_4_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[56]
            model_layers_5_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[57]
            model_layers_5_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[58]
            model_layers_5_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[59]
            model_layers_5_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[60]
            model_layers_5_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[61]
            model_layers_5_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[62]
            model_layers_5_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[63]
            model_layers_5_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[64]
            model_layers_5_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[65]
            model_layers_5_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[66]
            model_layers_5_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[67]
            model_layers_6_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[68]
            model_layers_6_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[69]
            model_layers_6_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[70]
            model_layers_6_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[71]
            model_layers_6_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[72]
            model_layers_6_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[73]
            model_layers_6_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[74]
            model_layers_6_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[75]
            model_layers_6_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[76]
            model_layers_6_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[77]
            model_layers_6_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[78]
            model_layers_7_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[79]
            model_layers_7_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[80]
            model_layers_7_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[81]
            model_layers_7_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[82]
            model_layers_7_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[83]
            model_layers_7_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[84]
            model_layers_7_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[85]
            model_layers_7_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[86]
            model_layers_7_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[87]
            model_layers_7_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[88]
            model_layers_7_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[89]
            model_layers_8_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[90]
            model_layers_8_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[91]
            model_layers_8_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[92]
            model_layers_8_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[93]
            model_layers_8_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[94]
            model_layers_8_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[95]
            model_layers_8_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[96]
            model_layers_8_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[97]
            model_layers_8_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[98]
            model_layers_8_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[99]
            model_layers_8_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[100]
            model_layers_9_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[101]
            model_layers_9_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[102]
            model_layers_9_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[103]
            model_layers_9_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[104]
            model_layers_9_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[105]
            model_layers_9_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[106]
            model_layers_9_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[107]
            model_layers_9_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[108]
            model_layers_9_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[109]
            model_layers_9_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[110]
            model_layers_9_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[111]
            model_layers_10_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[112]
            model_layers_10_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[113]
            model_layers_10_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[114]
            model_layers_10_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[115]
            model_layers_10_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[116]
            model_layers_10_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[117]
            model_layers_10_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[118]
            model_layers_10_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[119]
            model_layers_10_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[120]
            model_layers_10_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[121]
            model_layers_10_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[122]
            model_layers_11_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[123]
            model_layers_11_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[124]
            model_layers_11_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[125]
            model_layers_11_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[126]
            model_layers_11_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[127]
            model_layers_11_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[128]
            model_layers_11_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[129]
            model_layers_11_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[130]
            model_layers_11_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[131]
            model_layers_11_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[132]
            model_layers_11_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[133]
            model_layers_12_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[134]
            model_layers_12_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[135]
            model_layers_12_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[136]
            model_layers_12_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[137]
            model_layers_12_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[138]
            model_layers_12_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[139]
            model_layers_12_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[140]
            model_layers_12_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[141]
            model_layers_12_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[142]
            model_layers_12_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[143]
            model_layers_12_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[144]
            model_layers_13_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[145]
            model_layers_13_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[146]
            model_layers_13_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[147]
            model_layers_13_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[148]
            model_layers_13_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[149]
            model_layers_13_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[150]
            model_layers_13_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[151]
            model_layers_13_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[152]
            model_layers_13_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[153]
            model_layers_13_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_13_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[155]
            model_layers_14_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[156]
            model_layers_14_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[157]
            model_layers_14_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[158]
            model_layers_14_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[159]
            model_layers_14_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[160]
            model_layers_14_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[161]
            model_layers_14_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[162]
            model_layers_14_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[163]
            model_layers_14_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[164]
            model_layers_14_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[165]
            model_layers_14_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[166]
            model_layers_15_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[167]
            model_layers_15_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[168]
            model_layers_15_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[169]
            model_layers_15_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[170]
            model_layers_15_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[171]
            model_layers_15_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[172]
            model_layers_15_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[173]
            model_layers_15_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[174]
            model_layers_15_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[175]
            model_layers_15_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[176]
            model_layers_15_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[177]
            model_layers_16_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[178]
            model_layers_16_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[179]
            model_layers_16_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[180]
            model_layers_16_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[181]
            model_layers_16_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[182]
            model_layers_16_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[183]
            model_layers_16_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[184]
            model_layers_16_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[185]
            model_layers_16_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[186]
            model_layers_16_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_16_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_17_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[189]
            model_layers_17_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[190]
            model_layers_17_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[191]
            model_layers_17_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_17_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_17_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[194]
            model_layers_17_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[195]
            model_layers_17_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[196]
            model_layers_17_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[197]
            model_layers_17_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[198]
            model_layers_17_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[199]
            model_layers_18_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[200]
            model_layers_18_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[201]
            model_layers_18_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[202]
            model_layers_18_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[203]
            model_layers_18_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[204]
            model_layers_18_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[205]
            model_layers_18_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[206]
            model_layers_18_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[207]
            model_layers_18_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[208]
            model_layers_18_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[209]
            model_layers_18_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[210]
            model_layers_19_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[211]
            model_layers_19_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[212]
            model_layers_19_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[213]
            model_layers_19_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[214]
            model_layers_19_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[215]
            model_layers_19_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[216]
            model_layers_19_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[217]
            model_layers_19_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[218]
            model_layers_19_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[219]
            model_layers_19_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[220]
            model_layers_19_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_20_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[222]
            model_layers_20_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[223]
            model_layers_20_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[224]
            model_layers_20_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[225]
            model_layers_20_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[226]
            model_layers_20_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[227]
            model_layers_20_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[228]
            model_layers_20_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[229]
            model_layers_20_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[230]
            model_layers_20_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[231]
            model_layers_20_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[232]
            model_layers_21_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[233]
            model_layers_21_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[234]
            model_layers_21_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[235]
            model_layers_21_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[236]
            model_layers_21_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[237]
            model_layers_21_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[238]
            model_layers_21_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[239]
            model_layers_21_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[240]
            model_layers_21_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[241]
            model_layers_21_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[242]
            model_layers_21_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[243]
            model_layers_22_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[244]
            model_layers_22_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[245]
            model_layers_22_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[246]
            model_layers_22_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[247]
            model_layers_22_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[248]
            model_layers_22_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[249]
            model_layers_22_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[250]
            model_layers_22_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[251]
            model_layers_22_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[252]
            model_layers_22_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[253]
            model_layers_22_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[254]
            model_layers_23_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[255]
            model_layers_23_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[256]
            model_layers_23_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[257]
            model_layers_23_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[258]
            model_layers_23_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[259]
            model_layers_23_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[260]
            model_layers_23_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[261]
            model_layers_23_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[262]
            model_layers_23_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[263]
            model_layers_23_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[264]
            model_layers_23_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[265]
            model_layers_24_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[266]
            model_layers_24_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[267]
            model_layers_24_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[268]
            model_layers_24_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[269]
            model_layers_24_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[270]
            model_layers_24_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[271]
            model_layers_24_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[272]
            model_layers_24_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[273]
            model_layers_24_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[274]
            model_layers_24_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[275]
            model_layers_24_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[276]
            model_layers_25_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[277]
            model_layers_25_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[278]
            model_layers_25_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[279]
            model_layers_25_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[280]
            model_layers_25_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[281]
            model_layers_25_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[282]
            model_layers_25_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[283]
            model_layers_25_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[284]
            model_layers_25_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[285]
            model_layers_25_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[286]
            model_layers_25_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[287]
            model_layers_26_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[288]
            model_layers_26_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[289]
            model_layers_26_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[290]
            model_layers_26_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[291]
            model_layers_26_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[292]
            model_layers_26_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[293]
            model_layers_26_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[294]
            model_layers_26_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[295]
            model_layers_26_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[296]
            model_layers_26_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[297]
            model_layers_26_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[298]
            model_layers_27_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[299]
            model_layers_27_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[300]
            model_layers_27_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[301]
            model_layers_27_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[302]
            model_layers_27_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[303]
            model_layers_27_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[304]
            model_layers_27_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[305]
            model_layers_27_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[306]
            model_layers_27_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[307]
            model_layers_27_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[308]
            model_layers_27_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[309]
            model_layers_28_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[310]
            model_layers_28_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[311]
            model_layers_28_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[312]
            model_layers_28_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[313]
            model_layers_28_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[314]
            model_layers_28_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[315]
            model_layers_28_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[316]
            model_layers_28_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[317]
            model_layers_28_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[318]
            model_layers_28_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[319]
            model_layers_28_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[320]
            model_layers_29_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[321]
            model_layers_29_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[322]
            model_layers_29_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[323]
            model_layers_29_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[324]
            model_layers_29_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[325]
            model_layers_29_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[326]
            model_layers_29_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[327]
            model_layers_29_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[328]
            model_layers_29_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[329]
            model_layers_29_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[330]
            model_layers_29_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[331]
            model_layers_30_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[332]
            model_layers_30_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[333]
            model_layers_30_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[334]
            model_layers_30_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[335]
            model_layers_30_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[336]
            model_layers_30_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[337]
            model_layers_30_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[338]
            model_layers_30_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[339]
            model_layers_30_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[340]
            model_layers_30_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_30_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[342]
            model_layers_31_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[343]
            model_layers_31_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[344]
            model_layers_31_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[345]
            model_layers_31_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[346]
            model_layers_31_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[347]
            model_layers_31_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[348]
            model_layers_31_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[349]
            model_layers_31_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[350]
            model_layers_31_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[351]
            model_layers_31_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[352]
            model_layers_31_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[353]
            model_layers_32_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[354]
            model_layers_32_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[355]
            model_layers_32_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[356]
            model_layers_32_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[357]
            model_layers_32_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[358]
            model_layers_32_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[359]
            model_layers_32_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[360]
            model_layers_32_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[361]
            model_layers_32_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[362]
            model_layers_32_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[363]
            model_layers_32_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[364]
            model_layers_33_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[365]
            model_layers_33_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[366]
            model_layers_33_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[367]
            model_layers_33_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[368]
            model_layers_33_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[369]
            model_layers_33_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[370]
            model_layers_33_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[371]
            model_layers_33_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[372]
            model_layers_33_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[373]
            model_layers_33_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_33_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_34_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[376]
            model_layers_34_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[377]
            model_layers_34_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[378]
            model_layers_34_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_34_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_34_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[381]
            model_layers_34_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[382]
            model_layers_34_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[383]
            model_layers_34_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[384]
            model_layers_34_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[385]
            model_layers_34_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[386]
            model_layers_35_self_attn_c_attn_q_weight2: R.Tensor((2560, 256), dtype="uint32") = packed_params[387]
            model_layers_35_self_attn_c_attn_q_scale2: R.Tensor((2560, 64), dtype="float16") = packed_params[388]
            model_layers_35_self_attn_c_attn_bias2: R.Tensor((2560,), dtype="float16") = packed_params[389]
            model_layers_35_self_attn_o_proj_q_weight2: R.Tensor((2048, 256), dtype="uint32") = packed_params[390]
            model_layers_35_self_attn_o_proj_q_scale2: R.Tensor((2048, 64), dtype="float16") = packed_params[391]
            model_layers_35_mlp_gate_up_proj_q_weight2: R.Tensor((22016, 256), dtype="uint32") = packed_params[392]
            model_layers_35_mlp_gate_up_proj_q_scale2: R.Tensor((22016, 64), dtype="float16") = packed_params[393]
            model_layers_35_mlp_down_proj_q_weight2: R.Tensor((2048, 1376), dtype="uint32") = packed_params[394]
            model_layers_35_mlp_down_proj_q_scale2: R.Tensor((2048, 344), dtype="float16") = packed_params[395]
            model_layers_35_input_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[396]
            model_layers_35_post_attention_layernorm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[397]
            model_norm_weight2: R.Tensor((2048,), dtype="float16") = packed_params[398]
            rms_norm73 = R.call_tir(cls.rms_norm2, (input_embed, model_layers_0_input_layernorm_weight2), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv108 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_0_self_attn_c_attn_q_weight2, model_layers_0_self_attn_c_attn_q_scale2, rms_norm73, model_layers_0_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv217 = R.call_tir(cls.fused_reshape8_reshape9, (lv108,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv184 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), lv217), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv218 = R.call_tir(cls.fused_reshape10_reshape11, (lv184,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv327 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_0_self_attn_o_proj_q_weight2, model_layers_0_self_attn_o_proj_q_scale2, lv218), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv432 = R.call_tir(cls.fuse_add_norm_prefill, (lv327, input_embed, model_layers_0_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv433: R.Tensor((1, 1, 2048), dtype="float16") = lv432[1]
            rms_norm74: R.Tensor((1, 1, 2048), dtype="float16") = lv432[0]
            lv328 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_0_mlp_gate_up_proj_q_weight2, model_layers_0_mlp_gate_up_proj_q_scale2, rms_norm74), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv219 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv328,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv329 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_0_mlp_down_proj_q_weight2, model_layers_0_mlp_down_proj_q_scale2, lv219), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv434 = R.call_tir(cls.fuse_add_norm_prefill, (lv329, lv433, model_layers_1_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv435: R.Tensor((1, 1, 2048), dtype="float16") = lv434[1]
            rms_norm75: R.Tensor((1, 1, 2048), dtype="float16") = lv434[0]
            lv109 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_1_self_attn_c_attn_q_weight2, model_layers_1_self_attn_c_attn_q_scale2, rms_norm75, model_layers_1_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv221 = R.call_tir(cls.fused_reshape8_reshape9, (lv109,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv189 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), lv221), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv222 = R.call_tir(cls.fused_reshape10_reshape11, (lv189,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv330 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_1_self_attn_o_proj_q_weight2, model_layers_1_self_attn_o_proj_q_scale2, lv222), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv436 = R.call_tir(cls.fuse_add_norm_prefill, (lv330, lv435, model_layers_1_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv437: R.Tensor((1, 1, 2048), dtype="float16") = lv436[1]
            rms_norm76: R.Tensor((1, 1, 2048), dtype="float16") = lv436[0]
            lv331 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_1_mlp_gate_up_proj_q_weight2, model_layers_1_mlp_gate_up_proj_q_scale2, rms_norm76), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv223 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv331,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv332 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_1_mlp_down_proj_q_weight2, model_layers_1_mlp_down_proj_q_scale2, lv223), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv438 = R.call_tir(cls.fuse_add_norm_prefill, (lv332, lv437, model_layers_2_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv439: R.Tensor((1, 1, 2048), dtype="float16") = lv438[1]
            rms_norm77: R.Tensor((1, 1, 2048), dtype="float16") = lv438[0]
            lv110 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_2_self_attn_c_attn_q_weight2, model_layers_2_self_attn_c_attn_q_scale2, rms_norm77, model_layers_2_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv225 = R.call_tir(cls.fused_reshape8_reshape9, (lv110,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv194 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), lv225), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv226 = R.call_tir(cls.fused_reshape10_reshape11, (lv194,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv333 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_2_self_attn_o_proj_q_weight2, model_layers_2_self_attn_o_proj_q_scale2, lv226), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv440 = R.call_tir(cls.fuse_add_norm_prefill, (lv333, lv439, model_layers_2_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv441: R.Tensor((1, 1, 2048), dtype="float16") = lv440[1]
            rms_norm78: R.Tensor((1, 1, 2048), dtype="float16") = lv440[0]
            lv334 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_2_mlp_gate_up_proj_q_weight2, model_layers_2_mlp_gate_up_proj_q_scale2, rms_norm78), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv227 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv334,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv335 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_2_mlp_down_proj_q_weight2, model_layers_2_mlp_down_proj_q_scale2, lv227), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv442 = R.call_tir(cls.fuse_add_norm_prefill, (lv335, lv441, model_layers_3_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv443: R.Tensor((1, 1, 2048), dtype="float16") = lv442[1]
            rms_norm79: R.Tensor((1, 1, 2048), dtype="float16") = lv442[0]
            lv111 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_3_self_attn_c_attn_q_weight2, model_layers_3_self_attn_c_attn_q_scale2, rms_norm79, model_layers_3_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv229 = R.call_tir(cls.fused_reshape8_reshape9, (lv111,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv199 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), lv229), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv230 = R.call_tir(cls.fused_reshape10_reshape11, (lv199,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv336 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_3_self_attn_o_proj_q_weight2, model_layers_3_self_attn_o_proj_q_scale2, lv230), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv444 = R.call_tir(cls.fuse_add_norm_prefill, (lv336, lv443, model_layers_3_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv445: R.Tensor((1, 1, 2048), dtype="float16") = lv444[1]
            rms_norm80: R.Tensor((1, 1, 2048), dtype="float16") = lv444[0]
            lv337 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_3_mlp_gate_up_proj_q_weight2, model_layers_3_mlp_gate_up_proj_q_scale2, rms_norm80), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv231 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv337,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv338 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_3_mlp_down_proj_q_weight2, model_layers_3_mlp_down_proj_q_scale2, lv231), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv446 = R.call_tir(cls.fuse_add_norm_prefill, (lv338, lv445, model_layers_4_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv447: R.Tensor((1, 1, 2048), dtype="float16") = lv446[1]
            rms_norm81: R.Tensor((1, 1, 2048), dtype="float16") = lv446[0]
            lv112 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_4_self_attn_c_attn_q_weight2, model_layers_4_self_attn_c_attn_q_scale2, rms_norm81, model_layers_4_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv233 = R.call_tir(cls.fused_reshape8_reshape9, (lv112,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv204 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), lv233), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv234 = R.call_tir(cls.fused_reshape10_reshape11, (lv204,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv339 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_4_self_attn_o_proj_q_weight2, model_layers_4_self_attn_o_proj_q_scale2, lv234), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv448 = R.call_tir(cls.fuse_add_norm_prefill, (lv339, lv447, model_layers_4_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv449: R.Tensor((1, 1, 2048), dtype="float16") = lv448[1]
            rms_norm82: R.Tensor((1, 1, 2048), dtype="float16") = lv448[0]
            lv340 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_4_mlp_gate_up_proj_q_weight2, model_layers_4_mlp_gate_up_proj_q_scale2, rms_norm82), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv235 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv340,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv341 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_4_mlp_down_proj_q_weight2, model_layers_4_mlp_down_proj_q_scale2, lv235), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv450 = R.call_tir(cls.fuse_add_norm_prefill, (lv341, lv449, model_layers_5_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv451: R.Tensor((1, 1, 2048), dtype="float16") = lv450[1]
            rms_norm83: R.Tensor((1, 1, 2048), dtype="float16") = lv450[0]
            lv113 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_5_self_attn_c_attn_q_weight2, model_layers_5_self_attn_c_attn_q_scale2, rms_norm83, model_layers_5_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv237 = R.call_tir(cls.fused_reshape8_reshape9, (lv113,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv209 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), lv237), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv238 = R.call_tir(cls.fused_reshape10_reshape11, (lv209,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv342 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_5_self_attn_o_proj_q_weight2, model_layers_5_self_attn_o_proj_q_scale2, lv238), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv452 = R.call_tir(cls.fuse_add_norm_prefill, (lv342, lv451, model_layers_5_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv453: R.Tensor((1, 1, 2048), dtype="float16") = lv452[1]
            rms_norm84: R.Tensor((1, 1, 2048), dtype="float16") = lv452[0]
            lv343 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_5_mlp_gate_up_proj_q_weight2, model_layers_5_mlp_gate_up_proj_q_scale2, rms_norm84), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv239 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv343,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv344 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_5_mlp_down_proj_q_weight2, model_layers_5_mlp_down_proj_q_scale2, lv239), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv454 = R.call_tir(cls.fuse_add_norm_prefill, (lv344, lv453, model_layers_6_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv455: R.Tensor((1, 1, 2048), dtype="float16") = lv454[1]
            rms_norm85: R.Tensor((1, 1, 2048), dtype="float16") = lv454[0]
            lv114 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_6_self_attn_c_attn_q_weight2, model_layers_6_self_attn_c_attn_q_scale2, rms_norm85, model_layers_6_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv241 = R.call_tir(cls.fused_reshape8_reshape9, (lv114,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv214 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), lv241), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv242 = R.call_tir(cls.fused_reshape10_reshape11, (lv214,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv345 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_6_self_attn_o_proj_q_weight2, model_layers_6_self_attn_o_proj_q_scale2, lv242), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv456 = R.call_tir(cls.fuse_add_norm_prefill, (lv345, lv455, model_layers_6_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv457: R.Tensor((1, 1, 2048), dtype="float16") = lv456[1]
            rms_norm86: R.Tensor((1, 1, 2048), dtype="float16") = lv456[0]
            lv346 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_6_mlp_gate_up_proj_q_weight2, model_layers_6_mlp_gate_up_proj_q_scale2, rms_norm86), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv243 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv346,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv347 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_6_mlp_down_proj_q_weight2, model_layers_6_mlp_down_proj_q_scale2, lv243), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv458 = R.call_tir(cls.fuse_add_norm_prefill, (lv347, lv457, model_layers_7_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv459: R.Tensor((1, 1, 2048), dtype="float16") = lv458[1]
            rms_norm87: R.Tensor((1, 1, 2048), dtype="float16") = lv458[0]
            lv115 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_7_self_attn_c_attn_q_weight2, model_layers_7_self_attn_c_attn_q_scale2, rms_norm87, model_layers_7_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv245 = R.call_tir(cls.fused_reshape8_reshape9, (lv115,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv219_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), lv245), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv246 = R.call_tir(cls.fused_reshape10_reshape11, (lv219_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv348 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_7_self_attn_o_proj_q_weight2, model_layers_7_self_attn_o_proj_q_scale2, lv246), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv460 = R.call_tir(cls.fuse_add_norm_prefill, (lv348, lv459, model_layers_7_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv461: R.Tensor((1, 1, 2048), dtype="float16") = lv460[1]
            rms_norm88: R.Tensor((1, 1, 2048), dtype="float16") = lv460[0]
            lv349 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_7_mlp_gate_up_proj_q_weight2, model_layers_7_mlp_gate_up_proj_q_scale2, rms_norm88), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv247 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv349,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv350 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_7_mlp_down_proj_q_weight2, model_layers_7_mlp_down_proj_q_scale2, lv247), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv462 = R.call_tir(cls.fuse_add_norm_prefill, (lv350, lv461, model_layers_8_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv463: R.Tensor((1, 1, 2048), dtype="float16") = lv462[1]
            rms_norm89: R.Tensor((1, 1, 2048), dtype="float16") = lv462[0]
            lv116 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_8_self_attn_c_attn_q_weight2, model_layers_8_self_attn_c_attn_q_scale2, rms_norm89, model_layers_8_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv249 = R.call_tir(cls.fused_reshape8_reshape9, (lv116,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv224 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), lv249), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv250 = R.call_tir(cls.fused_reshape10_reshape11, (lv224,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv351 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_8_self_attn_o_proj_q_weight2, model_layers_8_self_attn_o_proj_q_scale2, lv250), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv464 = R.call_tir(cls.fuse_add_norm_prefill, (lv351, lv463, model_layers_8_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv465: R.Tensor((1, 1, 2048), dtype="float16") = lv464[1]
            rms_norm90: R.Tensor((1, 1, 2048), dtype="float16") = lv464[0]
            lv352 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_8_mlp_gate_up_proj_q_weight2, model_layers_8_mlp_gate_up_proj_q_scale2, rms_norm90), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv251 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv352,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv353 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_8_mlp_down_proj_q_weight2, model_layers_8_mlp_down_proj_q_scale2, lv251), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv466 = R.call_tir(cls.fuse_add_norm_prefill, (lv353, lv465, model_layers_9_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv467: R.Tensor((1, 1, 2048), dtype="float16") = lv466[1]
            rms_norm91: R.Tensor((1, 1, 2048), dtype="float16") = lv466[0]
            lv117 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_9_self_attn_c_attn_q_weight2, model_layers_9_self_attn_c_attn_q_scale2, rms_norm91, model_layers_9_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv253 = R.call_tir(cls.fused_reshape8_reshape9, (lv117,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv229_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), lv253), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv254 = R.call_tir(cls.fused_reshape10_reshape11, (lv229_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv354 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_9_self_attn_o_proj_q_weight2, model_layers_9_self_attn_o_proj_q_scale2, lv254), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv468 = R.call_tir(cls.fuse_add_norm_prefill, (lv354, lv467, model_layers_9_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv469: R.Tensor((1, 1, 2048), dtype="float16") = lv468[1]
            rms_norm92: R.Tensor((1, 1, 2048), dtype="float16") = lv468[0]
            lv355 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_9_mlp_gate_up_proj_q_weight2, model_layers_9_mlp_gate_up_proj_q_scale2, rms_norm92), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv255 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv355,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv356 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_9_mlp_down_proj_q_weight2, model_layers_9_mlp_down_proj_q_scale2, lv255), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv470 = R.call_tir(cls.fuse_add_norm_prefill, (lv356, lv469, model_layers_10_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv471: R.Tensor((1, 1, 2048), dtype="float16") = lv470[1]
            rms_norm93: R.Tensor((1, 1, 2048), dtype="float16") = lv470[0]
            lv118 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_10_self_attn_c_attn_q_weight2, model_layers_10_self_attn_c_attn_q_scale2, rms_norm93, model_layers_10_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv257 = R.call_tir(cls.fused_reshape8_reshape9, (lv118,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv234_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), lv257), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv258 = R.call_tir(cls.fused_reshape10_reshape11, (lv234_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv357 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_10_self_attn_o_proj_q_weight2, model_layers_10_self_attn_o_proj_q_scale2, lv258), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv472 = R.call_tir(cls.fuse_add_norm_prefill, (lv357, lv471, model_layers_10_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv473: R.Tensor((1, 1, 2048), dtype="float16") = lv472[1]
            rms_norm94: R.Tensor((1, 1, 2048), dtype="float16") = lv472[0]
            lv358 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_10_mlp_gate_up_proj_q_weight2, model_layers_10_mlp_gate_up_proj_q_scale2, rms_norm94), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv259 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv358,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv359 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_10_mlp_down_proj_q_weight2, model_layers_10_mlp_down_proj_q_scale2, lv259), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv474 = R.call_tir(cls.fuse_add_norm_prefill, (lv359, lv473, model_layers_11_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv475: R.Tensor((1, 1, 2048), dtype="float16") = lv474[1]
            rms_norm95: R.Tensor((1, 1, 2048), dtype="float16") = lv474[0]
            lv119 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_11_self_attn_c_attn_q_weight2, model_layers_11_self_attn_c_attn_q_scale2, rms_norm95, model_layers_11_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv261 = R.call_tir(cls.fused_reshape8_reshape9, (lv119,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv239_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), lv261), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv262 = R.call_tir(cls.fused_reshape10_reshape11, (lv239_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv360 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_11_self_attn_o_proj_q_weight2, model_layers_11_self_attn_o_proj_q_scale2, lv262), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv476 = R.call_tir(cls.fuse_add_norm_prefill, (lv360, lv475, model_layers_11_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv477: R.Tensor((1, 1, 2048), dtype="float16") = lv476[1]
            rms_norm96: R.Tensor((1, 1, 2048), dtype="float16") = lv476[0]
            lv361 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_11_mlp_gate_up_proj_q_weight2, model_layers_11_mlp_gate_up_proj_q_scale2, rms_norm96), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv263 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv361,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv362 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_11_mlp_down_proj_q_weight2, model_layers_11_mlp_down_proj_q_scale2, lv263), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv478 = R.call_tir(cls.fuse_add_norm_prefill, (lv362, lv477, model_layers_12_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv479: R.Tensor((1, 1, 2048), dtype="float16") = lv478[1]
            rms_norm97: R.Tensor((1, 1, 2048), dtype="float16") = lv478[0]
            lv120 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_12_self_attn_c_attn_q_weight2, model_layers_12_self_attn_c_attn_q_scale2, rms_norm97, model_layers_12_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv265 = R.call_tir(cls.fused_reshape8_reshape9, (lv120,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv244 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), lv265), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv266 = R.call_tir(cls.fused_reshape10_reshape11, (lv244,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv363 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_12_self_attn_o_proj_q_weight2, model_layers_12_self_attn_o_proj_q_scale2, lv266), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv480 = R.call_tir(cls.fuse_add_norm_prefill, (lv363, lv479, model_layers_12_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv481: R.Tensor((1, 1, 2048), dtype="float16") = lv480[1]
            rms_norm98: R.Tensor((1, 1, 2048), dtype="float16") = lv480[0]
            lv364 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_12_mlp_gate_up_proj_q_weight2, model_layers_12_mlp_gate_up_proj_q_scale2, rms_norm98), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv267 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv364,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv365 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_12_mlp_down_proj_q_weight2, model_layers_12_mlp_down_proj_q_scale2, lv267), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv482 = R.call_tir(cls.fuse_add_norm_prefill, (lv365, lv481, model_layers_13_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv483: R.Tensor((1, 1, 2048), dtype="float16") = lv482[1]
            rms_norm99: R.Tensor((1, 1, 2048), dtype="float16") = lv482[0]
            lv121 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_13_self_attn_c_attn_q_weight2, model_layers_13_self_attn_c_attn_q_scale2, rms_norm99, model_layers_13_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv269 = R.call_tir(cls.fused_reshape8_reshape9, (lv121,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv249_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), lv269), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv270 = R.call_tir(cls.fused_reshape10_reshape11, (lv249_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv366 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_13_self_attn_o_proj_q_weight2, model_layers_13_self_attn_o_proj_q_scale2, lv270), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv484 = R.call_tir(cls.fuse_add_norm_prefill, (lv366, lv483, model_layers_13_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv485: R.Tensor((1, 1, 2048), dtype="float16") = lv484[1]
            rms_norm100: R.Tensor((1, 1, 2048), dtype="float16") = lv484[0]
            lv367 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_13_mlp_gate_up_proj_q_weight2, model_layers_13_mlp_gate_up_proj_q_scale2, rms_norm100), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv271 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv367,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv368 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_13_mlp_down_proj_q_weight2, model_layers_13_mlp_down_proj_q_scale2, lv271), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv486 = R.call_tir(cls.fuse_add_norm_prefill, (lv368, lv485, model_layers_14_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv487: R.Tensor((1, 1, 2048), dtype="float16") = lv486[1]
            rms_norm101: R.Tensor((1, 1, 2048), dtype="float16") = lv486[0]
            lv122 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_14_self_attn_c_attn_q_weight2, model_layers_14_self_attn_c_attn_q_scale2, rms_norm101, model_layers_14_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv273 = R.call_tir(cls.fused_reshape8_reshape9, (lv122,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv254_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), lv273), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv274 = R.call_tir(cls.fused_reshape10_reshape11, (lv254_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv369 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_14_self_attn_o_proj_q_weight2, model_layers_14_self_attn_o_proj_q_scale2, lv274), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv488 = R.call_tir(cls.fuse_add_norm_prefill, (lv369, lv487, model_layers_14_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv489: R.Tensor((1, 1, 2048), dtype="float16") = lv488[1]
            rms_norm102: R.Tensor((1, 1, 2048), dtype="float16") = lv488[0]
            lv370 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_14_mlp_gate_up_proj_q_weight2, model_layers_14_mlp_gate_up_proj_q_scale2, rms_norm102), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv275 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv370,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv371 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_14_mlp_down_proj_q_weight2, model_layers_14_mlp_down_proj_q_scale2, lv275), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv490 = R.call_tir(cls.fuse_add_norm_prefill, (lv371, lv489, model_layers_15_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv491: R.Tensor((1, 1, 2048), dtype="float16") = lv490[1]
            rms_norm103: R.Tensor((1, 1, 2048), dtype="float16") = lv490[0]
            lv123 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_15_self_attn_c_attn_q_weight2, model_layers_15_self_attn_c_attn_q_scale2, rms_norm103, model_layers_15_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv277 = R.call_tir(cls.fused_reshape8_reshape9, (lv123,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv259_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), lv277), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv278 = R.call_tir(cls.fused_reshape10_reshape11, (lv259_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv372 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_15_self_attn_o_proj_q_weight2, model_layers_15_self_attn_o_proj_q_scale2, lv278), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv492 = R.call_tir(cls.fuse_add_norm_prefill, (lv372, lv491, model_layers_15_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv493: R.Tensor((1, 1, 2048), dtype="float16") = lv492[1]
            rms_norm104: R.Tensor((1, 1, 2048), dtype="float16") = lv492[0]
            lv373 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_15_mlp_gate_up_proj_q_weight2, model_layers_15_mlp_gate_up_proj_q_scale2, rms_norm104), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv279 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv373,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv374 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_15_mlp_down_proj_q_weight2, model_layers_15_mlp_down_proj_q_scale2, lv279), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv494 = R.call_tir(cls.fuse_add_norm_prefill, (lv374, lv493, model_layers_16_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv495: R.Tensor((1, 1, 2048), dtype="float16") = lv494[1]
            rms_norm105: R.Tensor((1, 1, 2048), dtype="float16") = lv494[0]
            lv124 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_16_self_attn_c_attn_q_weight2, model_layers_16_self_attn_c_attn_q_scale2, rms_norm105, model_layers_16_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv281 = R.call_tir(cls.fused_reshape8_reshape9, (lv124,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv264 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), lv281), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv282 = R.call_tir(cls.fused_reshape10_reshape11, (lv264,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv375 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_16_self_attn_o_proj_q_weight2, model_layers_16_self_attn_o_proj_q_scale2, lv282), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv496 = R.call_tir(cls.fuse_add_norm_prefill, (lv375, lv495, model_layers_16_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv497: R.Tensor((1, 1, 2048), dtype="float16") = lv496[1]
            rms_norm106: R.Tensor((1, 1, 2048), dtype="float16") = lv496[0]
            lv376 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_16_mlp_gate_up_proj_q_weight2, model_layers_16_mlp_gate_up_proj_q_scale2, rms_norm106), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv283 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv376,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv377 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_16_mlp_down_proj_q_weight2, model_layers_16_mlp_down_proj_q_scale2, lv283), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv498 = R.call_tir(cls.fuse_add_norm_prefill, (lv377, lv497, model_layers_17_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv499: R.Tensor((1, 1, 2048), dtype="float16") = lv498[1]
            rms_norm107: R.Tensor((1, 1, 2048), dtype="float16") = lv498[0]
            lv125 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_17_self_attn_c_attn_q_weight2, model_layers_17_self_attn_c_attn_q_scale2, rms_norm107, model_layers_17_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv285 = R.call_tir(cls.fused_reshape8_reshape9, (lv125,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv269_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), lv285), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv286 = R.call_tir(cls.fused_reshape10_reshape11, (lv269_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv378 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_17_self_attn_o_proj_q_weight2, model_layers_17_self_attn_o_proj_q_scale2, lv286), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv500 = R.call_tir(cls.fuse_add_norm_prefill, (lv378, lv499, model_layers_17_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv501: R.Tensor((1, 1, 2048), dtype="float16") = lv500[1]
            rms_norm108: R.Tensor((1, 1, 2048), dtype="float16") = lv500[0]
            lv379 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_17_mlp_gate_up_proj_q_weight2, model_layers_17_mlp_gate_up_proj_q_scale2, rms_norm108), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv287 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv379,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv380 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_17_mlp_down_proj_q_weight2, model_layers_17_mlp_down_proj_q_scale2, lv287), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv502 = R.call_tir(cls.fuse_add_norm_prefill, (lv380, lv501, model_layers_18_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv503: R.Tensor((1, 1, 2048), dtype="float16") = lv502[1]
            rms_norm109: R.Tensor((1, 1, 2048), dtype="float16") = lv502[0]
            lv126 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_18_self_attn_c_attn_q_weight2, model_layers_18_self_attn_c_attn_q_scale2, rms_norm109, model_layers_18_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv289 = R.call_tir(cls.fused_reshape8_reshape9, (lv126,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv274_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), lv289), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv290 = R.call_tir(cls.fused_reshape10_reshape11, (lv274_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv381 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_18_self_attn_o_proj_q_weight2, model_layers_18_self_attn_o_proj_q_scale2, lv290), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv504 = R.call_tir(cls.fuse_add_norm_prefill, (lv381, lv503, model_layers_18_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv505: R.Tensor((1, 1, 2048), dtype="float16") = lv504[1]
            rms_norm110: R.Tensor((1, 1, 2048), dtype="float16") = lv504[0]
            lv382 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_18_mlp_gate_up_proj_q_weight2, model_layers_18_mlp_gate_up_proj_q_scale2, rms_norm110), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv291 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv382,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv383 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_18_mlp_down_proj_q_weight2, model_layers_18_mlp_down_proj_q_scale2, lv291), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv506 = R.call_tir(cls.fuse_add_norm_prefill, (lv383, lv505, model_layers_19_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv507: R.Tensor((1, 1, 2048), dtype="float16") = lv506[1]
            rms_norm111: R.Tensor((1, 1, 2048), dtype="float16") = lv506[0]
            lv127 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_19_self_attn_c_attn_q_weight2, model_layers_19_self_attn_c_attn_q_scale2, rms_norm111, model_layers_19_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv293 = R.call_tir(cls.fused_reshape8_reshape9, (lv127,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv279_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), lv293), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv294 = R.call_tir(cls.fused_reshape10_reshape11, (lv279_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv384 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_19_self_attn_o_proj_q_weight2, model_layers_19_self_attn_o_proj_q_scale2, lv294), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv508 = R.call_tir(cls.fuse_add_norm_prefill, (lv384, lv507, model_layers_19_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv509: R.Tensor((1, 1, 2048), dtype="float16") = lv508[1]
            rms_norm112: R.Tensor((1, 1, 2048), dtype="float16") = lv508[0]
            lv385 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_19_mlp_gate_up_proj_q_weight2, model_layers_19_mlp_gate_up_proj_q_scale2, rms_norm112), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv295 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv385,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv386 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_19_mlp_down_proj_q_weight2, model_layers_19_mlp_down_proj_q_scale2, lv295), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv510 = R.call_tir(cls.fuse_add_norm_prefill, (lv386, lv509, model_layers_20_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv511: R.Tensor((1, 1, 2048), dtype="float16") = lv510[1]
            rms_norm113: R.Tensor((1, 1, 2048), dtype="float16") = lv510[0]
            lv128 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_20_self_attn_c_attn_q_weight2, model_layers_20_self_attn_c_attn_q_scale2, rms_norm113, model_layers_20_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv297 = R.call_tir(cls.fused_reshape8_reshape9, (lv128,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv284 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), lv297), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv298 = R.call_tir(cls.fused_reshape10_reshape11, (lv284,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv387 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_20_self_attn_o_proj_q_weight2, model_layers_20_self_attn_o_proj_q_scale2, lv298), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv512 = R.call_tir(cls.fuse_add_norm_prefill, (lv387, lv511, model_layers_20_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv513: R.Tensor((1, 1, 2048), dtype="float16") = lv512[1]
            rms_norm114: R.Tensor((1, 1, 2048), dtype="float16") = lv512[0]
            lv388 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_20_mlp_gate_up_proj_q_weight2, model_layers_20_mlp_gate_up_proj_q_scale2, rms_norm114), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv299 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv388,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv389 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_20_mlp_down_proj_q_weight2, model_layers_20_mlp_down_proj_q_scale2, lv299), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv514 = R.call_tir(cls.fuse_add_norm_prefill, (lv389, lv513, model_layers_21_input_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv515: R.Tensor((1, 1, 2048), dtype="float16") = lv514[1]
            rms_norm115: R.Tensor((1, 1, 2048), dtype="float16") = lv514[0]
            lv129 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul10_add2, (model_layers_21_self_attn_c_attn_q_weight2, model_layers_21_self_attn_c_attn_q_scale2, rms_norm115, model_layers_21_self_attn_c_attn_bias2), out_sinfo=R.Tensor((1, 1, 2560), dtype="float16"))
            lv301 = R.call_tir(cls.fused_reshape8_reshape9, (lv129,), out_sinfo=R.Tensor((1, 20, 128), dtype="float16"))
            lv289_1 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), lv301), out_sinfo=R.Tensor((1, 16, 128), dtype="float16"))
            lv302 = R.call_tir(cls.fused_reshape10_reshape11, (lv289_1,), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv390 = R.call_tir(cls.fused_dequantize2_NT_matmul11, (model_layers_21_self_attn_o_proj_q_weight2, model_layers_21_self_attn_o_proj_q_scale2, lv302), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv516 = R.call_tir(cls.fuse_add_norm_prefill, (lv390, lv515, model_layers_21_post_attention_layernorm_weight2), out_sinfo=[R.Tensor((1, 1, 2048), dtype="float16"), R.Tensor((1, 1, 2048), dtype="float16")])
            lv517: R.Tensor((1, 1, 2048), dtype="float16") = lv516[1]
            rms_norm116: R.Tensor((1, 1, 2048), dtype="float16") = lv516[0]
            lv391 = R.call_tir(cls.fused_dequantize3_NT_matmul12, (model_layers_21_mlp_gate_up_proj_q_weight2, model_layers_21_mlp_gate_up_proj_q_scale2, rms_norm116), out_sinfo=R.Tensor((1, 1, 22016), dtype="float16"))
            lv303 = R.call_tir(cls.fused_split2_silu2_multiply2, (lv391,), out_sinfo=R.Tensor((1, 1, 11008), dtype="float16"))
            lv392 = R.call_tir(cls.fused_dequantize4_NT_matmul13, (model_layers_21_mlp_down_proj_q_weight2, model_layers_21_mlp_down_proj_q_scale2, lv303), out_sinfo=R.Tensor((1, 1, 2048), dtype="float16"))
            lv518 = R.call_tir(cls.