# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func
    def apply_bitmask_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_bitmask: T.handle):
        # Mask logits in place using a packed bitmask.
        #
        # For every s in [0, num_seq) and v in [0, vocab_size), with
        # row = seq_ids[s]: logits[row, v] is kept when bit (v % 32) of
        # bitmask[row, v // 32] is 1, and is otherwise overwritten with the
        # float32 constant used below (the most negative finite float32).
        #
        # Buffers bound from the handles:
        #   logits  : (batch_size, vocab_size), dtype not given to T.match_buffer
        #             (compared/assigned against T.float32 below); read and written.
        #   seq_ids : (num_seq,) int32, used as the row index into both logits
        #             and bitmask.
        #   bitmask : (batch_size, ceil(vocab_size / 32)) int32, one bit per
        #             vocabulary entry.
        # NOTE(review): nothing here range-checks seq_ids[s] against batch_size;
        # presumably the caller guarantees valid row ids — confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
        # with T.block("root"):
        # The (s, v) iteration space is flattened to num_seq * vocab_size elements
        # and split into chunks of 1024: the outer loop is bound to blockIdx.x
        # (ceil-divided extent) and the inner loop to threadIdx.x.
        for fused_s_v_0 in T.thread_binding((num_seq * vocab_size + 1023) // 1024, thread="blockIdx.x"):
            for fused_s_v_1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    # Recover (s, v) from the flattened index.
                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * 1024 + fused_s_v_1) // vocab_size)
                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * 1024 + fused_s_v_1) % vocab_size)
                    # Predicate for the tail: the last chunk may extend past
                    # num_seq * vocab_size because of the ceil-division above.
                    T.where(fused_s_v_0 * 1024 + fused_s_v_1 < num_seq * vocab_size)
                    T.reads(bitmask[seq_ids[vs], vv // 32], seq_ids[vs], logits[seq_ids[vs], vv])
                    T.writes(logits[seq_ids[vs], vv])
                    # Extract bit (vv % 32) of the 32-bit word holding this vocab
                    # entry; keep the logit when the bit is 1, else store the
                    # negative float32 constant.
                    logits[seq_ids[vs], vv] = T.if_then_else(T.bitwise_and(T.shift_right(bitmask[seq_ids[vs], vv // 32], vv % 32), 1) == 1, logits[seq_ids[vs], vv], T.float32(-340282346638528859811704183484516925440.0))

    @T.prim_func
    def apply_logit_bias_inplace(var_logits: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_logit_bias: T.handle):
        # Add a per-entry bias to selected logits, in place.
        #
        # For every p in [0, num_token):
        #   logits[pos2seq_id[p], token_ids[p]] += logit_bias[p]
        #
        # Buffers bound from the handles:
        #   logits     : (batch_size, vocab_size), dtype not given to
        #                T.match_buffer; read and written.
        #   pos2seq_id : (num_token,) int32, row index into logits for entry p.
        #   token_ids  : (num_token,) int32, column index into logits for entry p.
        #   logit_bias : (num_token,), dtype not given to T.match_buffer; the
        #                value added for entry p.
        # NOTE(review): the update is a plain read-modify-write. If two entries
        # share the same (pos2seq_id, token_ids) pair the result looks
        # order-dependent; presumably the caller guarantees unique pairs — confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        logit_bias = T.match_buffer(var_logit_bias, (num_token,))
        # with T.block("root"):
        # num_token entries are split into chunks of 1024: the outer loop is
        # bound to blockIdx.x (ceil-divided extent), the inner to threadIdx.x.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Predicate for the tail chunk, which may run past num_token.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[pos2seq_id[vp], token_ids[vp]], pos2seq_id[vp], token_ids[vp], logit_bias[vp])
                    T.writes(logits[pos2seq_id[vp], token_ids[vp]])
                    logits[pos2seq_id[vp], token_ids[vp]] = logits[pos2seq_id[vp], token_ids[vp]] + logit_bias[vp]

    @T.prim_func
    def apply_penalty_inplace(var_logits: T.handle, var_seq_ids: T.handle, var_pos2seq_id: T.handle, var_token_ids: T.handle, var_token_cnt: T.handle, var_penalties: T.handle):
        # Apply per-sequence penalties to selected logits, in place.
        #
        # For every p in [0, num_token), with s = pos2seq_id[p] and
        # x = logits[seq_ids[s], token_ids[p]], two updates run in order:
        #   1. x -= penalties[s, 0] + float32(token_cnt[p]) * penalties[s, 1]
        #   2. x  = x * penalties[s, 2] if x < 0 else x / penalties[s, 2]
        #
        # Buffers bound from the handles:
        #   logits     : (batch_size, vocab_size), dtype not given to
        #                T.match_buffer; read and written.
        #   seq_ids    : (num_seq,) int32, maps s to the logits row.
        #   pos2seq_id : (num_token,) int32, maps entry p to s; s indexes both
        #                seq_ids and penalties.
        #   token_ids  : (num_token,) int32, column index into logits.
        #   token_cnt  : (num_token,) int32, cast to float32 before scaling.
        #   penalties  : (num_seq, 3), dtype not given to T.match_buffer.
        # NOTE(review): the three penalties columns are presumably presence,
        # frequency and repetition penalties — confirm against the caller.
        # NOTE(review): step 2 divides by penalties[s, 2] with no zero check;
        # presumably the caller guarantees a non-zero value — confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        logits = T.match_buffer(var_logits, (batch_size, vocab_size))
        num_seq = T.int32(is_size_var=True)
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        num_token = T.int32(is_size_var=True)
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        penalties = T.match_buffer(var_penalties, (num_seq, 3))
        # with T.block("root"):
        # num_token entries are split into chunks of 1024: the outer loop is
        # bound to blockIdx.x (ceil-divided extent), the inner to threadIdx.x.
        for p0 in T.thread_binding((num_token + 1023) // 1024, thread="blockIdx.x"):
            for p1 in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("block"):
                    vp = T.axis.spatial(num_token, p0 * 1024 + p1)
                    # Predicate for the tail chunk, which may run past num_token.
                    T.where(p0 * 1024 + p1 < num_token)
                    T.reads(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]], seq_ids[pos2seq_id[vp]], pos2seq_id[vp], token_ids[vp], penalties[pos2seq_id[vp], 0:3], token_cnt[vp])
                    T.writes(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]])
                    # Step 1: subtract the additive term built from columns 0 and 1.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] - (penalties[pos2seq_id[vp], 0] + T.Cast("float32", token_cnt[vp]) * penalties[pos2seq_id[vp], 1])
                    # Step 2: scale by column 2, using the value written in step 1;
                    # negative values are multiplied, non-negative values divided.
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0.0), logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2], logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2])

    @T.prim_func(private=True)
    def argsort(var_probs: T.handle, var_argsort_gpu_v1: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size), offset_factor=1)
        out_buf = T.match_buffer(var_argsort_gpu_v1, (batch_size, vocab_size), "int32", align=8)
        # with T.block("root"):
        value_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        value_swap_buf = T.alloc_buffer((batch_size, vocab_size), align=8)
        out_swap_buf = T.alloc_buffer((batch_size, vocab_size), "int32", align=8)
        with T.block("argsort_gpu"):
            T.reads()
            T.writes()
            if vocab_size > T.int64(0):
                with T.launch_thread("threadIdx.x", T.int64(1024)) as threadIdx_x:
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(1023)) // T.int64(1024)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    if blockIdx_x * T.int64(1024) + threadIdx_x < vocab_size:
                        value_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = probs[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size]
                        out_buf[(blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) // vocab_size, (blockIdx_y % batch_size * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x) + blockIdx_y // batch_size) % vocab_size] = T.Cast("int32", blockIdx_x * T.int64(1024) + threadIdx_x)
                with T.attr(0, "hand_threaded", 0):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.int64(64))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(127)) // T.int64(128)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    temp_keys_swap = T.allocate([T.int64(128)], "float32", "shared")
                    temp_values_swap = T.allocate([T.int64(128)], "int32", "shared")
                    temp_keys = T.allocate([T.int64(1)], "float32", "local")
                    temp_values = T.allocate([T.int64(1)], "int32", "local")
                    temp_cond1 = T.allocate([T.int64(1)], "float32", "local")
                    temp_cond2 = T.allocate([T.int64(1)], "float32", "local")
                    temp_keys_swap_1 = T.Buffer((128,), data=temp_keys_swap, scope="shared")
                    temp_values_swap_1 = T.Buffer((128,), "int32", data=temp_values_swap, scope="shared")
                    for i in range(T.int64(2)):
                        if T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128) < vocab_size:
                            temp_keys_swap_1[T.int64(2) * threadIdx_x + i] = value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) % vocab_size]
                            temp_values_swap_1[T.int64(2) * threadIdx_x + i] = out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + i + blockIdx_x * T.int64(128))) % vocab_size]
                    T.tvm_storage_sync("shared")
                    for j in range(T.min(T.int64(128), vocab_size - blockIdx_x * T.int64(128))):
                        if T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) < T.min(T.int64(128), vocab_size - blockIdx_x * T.int64(128)) - T.int64(1):
                            temp_cond1_1 = T.Buffer((1,), data=temp_cond1, scope="local")
                            temp_cond1_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                            temp_cond2_1 = T.Buffer((1,), data=temp_cond2, scope="local")
                            temp_cond2_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                            if temp_cond1_1[T.int64(0)] < temp_cond2_1[T.int64(0)]:
                                temp_keys_1 = T.Buffer((1,), data=temp_keys, scope="local")
                                temp_keys_1[T.int64(0)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                                temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)] = temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                                temp_keys_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)] = temp_keys_1[T.int64(0)]
                                temp_values_1 = T.Buffer((1,), "int32", data=temp_values, scope="local")
                                temp_values_1[T.int64(0)] = temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)]
                                temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2)] = temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)]
                                temp_values_swap_1[T.int64(2) * threadIdx_x + (T.int64(2) * threadIdx_x + j) % T.int64(2) + T.int64(1)] = temp_values_1[T.int64(0)]
                        T.tvm_storage_sync("shared")
                    for k in range(T.int64(2)):
                        if T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128) < vocab_size:
                            value_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_keys_swap_1[T.int64(2) * threadIdx_x + k]
                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_keys_swap_1[T.int64(2) * threadIdx_x + k]
                            out_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_values_swap_1[T.int64(2) * threadIdx_x + k]
                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) // vocab_size, (blockIdx_y % batch_size * vocab_size + blockIdx_y // batch_size + (T.int64(2) * threadIdx_x + k + blockIdx_x * T.int64(128))) % vocab_size] = temp_values_swap_1[T.int64(2) * threadIdx_x + k]
                for i_0 in range(T.Cast("int64", T.ceil(T.log2(T.Cast("float64", vocab_size)))) - T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.max(T.int64(1), T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + T.int64(4095)) // T.int64(4096)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size * ((vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) - T.int64(1))) // T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))))
                    if T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) < vocab_size:
                        if (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + T.int64(4095)) // T.int64(4096) == T.int64(1):
                            if i_0 % T.int64(2) == T.int64(0):
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                i = T.allocate([T.int64(1)], "int64", "local")
                                j = T.allocate([T.int64(1)], "int64", "local")
                                i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]
                                j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - last_1[T.int64(0)]
                                for i_1_1 in range(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)) - threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))))):
                                    if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)):
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                    else:
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)):
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_1_1)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                            else:
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                i = T.allocate([T.int64(1)], "int64", "local")
                                j = T.allocate([T.int64(1)], "int64", "local")
                                i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]
                                j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) - last_1[T.int64(0)]
                                for i_2 in range(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)) - threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))), (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))))):
                                    if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)):
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                    else:
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size)):
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                            i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                        else:
                                            value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + threadIdx_x * ((T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) + (T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))))) - T.int64(1))) // T.min(T.int64(1024), T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))))) + i_2)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                            j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                        else:
                            if i_0 % T.int64(2) == T.int64(0):
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), blockIdx_x * T.int64(4096) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(blockIdx_x * T.int64(4096), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                if i_0 % T.int64(2) == T.int64(0):
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_3 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_3)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                else:
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_4 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_4)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                            else:
                                first = T.allocate([T.int64(1)], "int64", "local")
                                mid = T.allocate([T.int64(1)], "int64", "local")
                                last = T.allocate([T.int64(1)], "int64", "local")
                                first_1 = T.Buffer((1,), "int64", data=first, scope="local")
                                first_1[T.int64(0)] = T.max(T.int64(0), blockIdx_x * T.int64(4096) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size)))
                                last_1 = T.Buffer((1,), "int64", data=last, scope="local")
                                last_1[T.int64(0)] = T.min(blockIdx_x * T.int64(4096), T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size))
                                while first_1[T.int64(0)] < last_1[T.int64(0)]:
                                    if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - T.int64(1) - T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)))) % vocab_size]:
                                        first_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1)) + T.int64(1)
                                    else:
                                        last_1[T.int64(0)] = T.shift_right(first_1[T.int64(0)] + last_1[T.int64(0)], T.int64(1))
                                if i_0 % T.int64(2) == T.int64(0):
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_5 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = value_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_5)) % vocab_size] = out_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                else:
                                    first_2 = T.allocate([T.int64(1)], "int64", "local")
                                    mid_1 = T.allocate([T.int64(1)], "int64", "local")
                                    last_2 = T.allocate([T.int64(1)], "int64", "local")
                                    first_3 = T.Buffer((1,), "int64", data=first_2, scope="local")
                                    first_3[T.int64(0)] = T.max(T.int64(0), threadIdx_x * T.int64(4) - T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)))
                                    last_3 = T.Buffer((1,), "int64", data=last_2, scope="local")
                                    last_3[T.int64(0)] = T.min(threadIdx_x * T.int64(4), T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)))
                                    while first_3[T.int64(0)] < last_3[T.int64(0)]:
                                        if value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - T.int64(1) - T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)))) % vocab_size]:
                                            first_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1)) + T.int64(1)
                                        else:
                                            last_3[T.int64(0)] = T.shift_right(first_3[T.int64(0)] + last_3[T.int64(0)], T.int64(1))
                                    i = T.allocate([T.int64(1)], "int64", "local")
                                    j = T.allocate([T.int64(1)], "int64", "local")
                                    i_1 = T.Buffer((1,), "int64", data=i, scope="local")
                                    i_1[T.int64(0)] = T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + first_3[T.int64(0)]
                                    j_1 = T.Buffer((1,), "int64", data=j, scope="local")
                                    j_1[T.int64(0)] = T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + threadIdx_x * T.int64(4) - last_3[T.int64(0)]
                                    for i_6 in range(T.min(T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)) - threadIdx_x * T.int64(4), T.int64(4))):
                                        if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)) and j_1[T.int64(0)] < T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))), vocab_size) - (T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) + blockIdx_x * T.int64(4096) - last_1[T.int64(0)]), T.int64(4096)):
                                            if value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size] <= value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                                        else:
                                            if i_1[T.int64(0)] < T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)] + T.min(T.min(T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) // T.int64(2), vocab_size) - (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + first_1[T.int64(0)]), T.int64(4096)):
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + i_1[T.int64(0)]) % vocab_size]
                                                i_1[T.int64(0)] = i_1[T.int64(0)] + T.int64(1)
                                            else:
                                                value_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = value_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                out_buf[(blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) // vocab_size, (blockIdx_y % batch_size * vocab_size + (T.shift_left(T.int64(2), i_0 + T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) * (blockIdx_y // batch_size) + blockIdx_x * T.int64(4096) + threadIdx_x * T.int64(4) + i_6)) % vocab_size] = out_swap_buf[(blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) // vocab_size, (blockIdx_y % batch_size * vocab_size + j_1[T.int64(0)]) % vocab_size]
                                                j_1[T.int64(0)] = j_1[T.int64(0)] + T.int64(1)
                if T.Cast("int64", T.ceil(T.log2(T.Cast("float64", vocab_size)))) > T.Cast("int64", T.ceil(T.log2(T.float64(128.0)))) and (T.Cast("int64", T.ceil(T.log2(T.Cast("float64", vocab_size)))) - T.Cast("int64", T.ceil(T.log2(T.float64(128.0))))) % T.int64(2) == T.int64(1):
                    threadIdx_x = T.launch_thread("threadIdx.x", T.int64(1024))
                    blockIdx_x = T.launch_thread("blockIdx.x", T.max(T.int64(1), (vocab_size + T.int64(1023)) // T.int64(1024)))
                    blockIdx_y = T.launch_thread("blockIdx.y", T.max(T.int64(1), batch_size))
                    if blockIdx_x * T.int64(1024) + threadIdx_x < vocab_size:
                        value_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size] = value_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size]
                        out_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size] = out_swap_buf[(blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) // vocab_size, (blockIdx_y * vocab_size + (blockIdx_x * T.int64(1024) + threadIdx_x)) % vocab_size]

    @T.prim_func
    def batch_decode_paged_kv(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Attention of one query vector per (batch entry, head) against K/V rows
        # stored in a paged cache, written as scheduled TIR for a CUDA target.
        #
        # For each batch entry `bx` and head `by + bz + ty` the kernel walks the
        # entry's K/V rows in tiles of 32, keeps a running maximum (`st_m`),
        # denominator (`st_d`) and weighted-V accumulator (`O_local`) using base-2
        # exponentials, merges the 16 threadIdx.z partial states, and stores the
        # normalized accumulator to `output` and `st_m + log2(st_d)` to `lse`.
        #
        # Arguments (shapes/dtypes are the ones declared by match_buffer below):
        #   _0                        not referenced anywhere in the body.
        #   Q_handle                  float16 (B, 16, 128) queries.
        #   pages_handle              float16 (max_num_pages, 2, 16, 16, 128); axis 1 is
        #                             read at index 0 for K and 1 for V, axis 2 is indexed
        #                             by head, axis 3 by the row offset inside a page.
        #   page_table_indptr_handle  int32 (B + 1,); entries [bx, bx + 1] delimit the
        #                             slice of page_table_values used by batch entry bx.
        #   page_table_values_handle  int32 (nnz_pages,); page numbers indexing `pages`.
        #   var_length_info           int32 (B,); added to (num_pages - 1) * 16 to form the
        #                             KV length -- presumably the number of used rows in the
        #                             last page; TODO confirm against the caller.
        #   k_rope_pos_offset_handle  int32 (B,); added to the K row index to form the
        #                             position used in the rotary frequency.
        #   q_rope_position_handle    int32 (B,); position used in the query's rotary
        #                             frequency.
        #   output_handle             float16 (B, 16, 128) result.
        #   lse_handle                (B, 16) buffer with the default match_buffer dtype.
        #   rotary_mode               the rotary transform is applied to Q and K only
        #                             when this equals 1.
        #   rope_scale, rope_theta    freq = pos * rope_scale / rope_theta ** (((d * 2) % 128) / 128).
        #   attn_score_scaling_factor multiplied into every Q*K product with sm_scale.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 16, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (B,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # Numerically equal to log2(e) / sqrt(128); the log2(e) factor is consistent
        # with the exp2/log2 arithmetic used below.
        # NOTE(review): derived from the constant's value -- confirm with the generator.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Launch shape: blockIdx.x = batch entry, blockIdx.y = 16 (head),
        # threadIdx = (x: 32, y: 1, z: 16).
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(16, thread="blockIdx.y"):
                for ty in T.thread_binding(1, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(16, thread="threadIdx.z"):
                            with T.block("attn"):
                                # The Q region spans [tx*4 - 64, tx*4 + 68): the thread's own
                                # 4 elements plus the +/-64 partners used by the rotary branch.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[bx], q_rope_position[bx], Q[bx, fused_by_bz // 16 + ty + fused_by_bz % 16, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 16 + fused_by_bz // 16 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 16 + fused_by_bz // 16 + ty])
                                # Per-thread scratch ("local") and per-block staging ("shared").
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                                O_allreduce = T.alloc_buffer((16, 1, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((16, 1, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((2,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                # blockIdx.y has extent 16, so by is in [0, 16) and bz is always 0;
                                # the head index used throughout is by + bz + ty.
                                by: T.int32 = fused_by_bz % 16
                                bz: T.int32 = fused_by_bz // 16
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # KV length: 16 rows for every page but the last, plus
                                # length_info for the last one; 0 when the entry has no pages.
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[batch_idx], 0)
                                # Initial running state: max = -50000 sentinel, denominator = 1,
                                # accumulator = 0.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements (d = tx*4 + vec). With
                                # rotary_mode == 1: cos(freq)*q[d] + sin(freq)*r, where r is
                                # -q[d + 64] for d < 64 and q[d - 64] otherwise.
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by + bz + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by + bz + ty, tx * 4 + vec])
                                # Walk the KV rows in tiles of 32; each tz value owns 2 rows per tile.
                                for iterator in range((kv_chunk_len[0] + 31) // 32):
                                    # Row index inside the shared tile and inside the whole KV sequence.
                                    tile_start_s: T.int32 = (tz + ty) * 2
                                    tile_start_g: T.int32 = (iterator * 16 + tz + ty) * 2
                                    for j in range(2):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Translate the sequence row into (page number, offset in page).
                                                seq_offset: T.int32 = row_g
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    # K gets the same rotary form as Q, at position
                                                    # k_rope_pos_offset + row_g; V is copied unchanged.
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Rows past the end of the sequence are zero-filled.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Remember the running max from before this tile for rescaling.
                                    m_prev[0] = st_m[0]
                                    for j in range(2):
                                        # Per-thread partial of the scaled Q.K dot product (4 of 128 elements).
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        # Sum the partials across threadIdx.x (tx) into t0[0].
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Rows beyond kv_chunk_len keep the -50000 sentinel score.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 16 + tz) * 2 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous denominator/accumulator to the new max,
                                    # then fold in this tile's two rows.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(2):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this tz's partial (accumulator, max, denominator) to shared memory.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Reset the running state and merge the 16 per-tz partials.
                                # The merge and the stores below do not depend on tz, so every
                                # tz value computes and writes the same result.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(16):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    # Bring both sides to the common max before adding them.
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize by the merged denominator and store as float16.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                # Base-2 log-sum-exp of the scaled scores: max + log2(denominator).
                                lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_decode_paged_kv_sliding_window(_0: T.int32, Q_handle: T.handle, pages_handle: T.handle, page_table_indptr_handle: T.handle, page_table_values_handle: T.handle, var_length_info: T.handle, k_rope_pos_offset_handle: T.handle, q_rope_position_handle: T.handle, output_handle: T.handle, lse_handle: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Single-query attention over a paged KV cache, one (batch entry, head) pair
        # per thread block. For each pair it reads one 128-wide query row from Q,
        # streams the entry's K/V rows out of `pages` in tiles of 32 rows, and keeps a
        # running max / denominator / weighted-V accumulator in base 2 (exp2 / log2).
        # The 16 threadIdx.z slices each accumulate a partial result, which is merged
        # through shared memory before the normalized output and the log-sum-exp
        # value are stored.
        #
        # Parameters (as used in the body below):
        #   _0: not referenced anywhere in the body.
        #   Q_handle: (B, 16, 128) float16 queries.
        #   pages_handle: (max_num_pages, 2, 16, 16, 128) float16. Index 0 of axis 1
        #     is read as K and index 1 as V; axis 2 is indexed by head, axis 3 by
        #     the slot inside the page (seq_offset % 16).
        #   page_table_indptr_handle: (B + 1,) int32; entries bx and bx + 1 delimit
        #     the slice of page_table_values belonging to batch entry bx.
        #   page_table_values_handle: (nnz_pages,) int32 page numbers into `pages`.
        #   var_length_info: (3, B) int32. Row 0 is added to and row 1 subtracted from
        #     the KV length; row 2 is added to the KV length and is also the
        #     threshold below which a logical row maps to itself (rows at or beyond
        #     it are shifted forward by row 1).
        #     NOTE(review): presumably last-page length, sliding-window offset and
        #     attention-sink size -- confirm against the code that fills this buffer.
        #   k_rope_pos_offset_handle: (B,) int32 base position added to the logical
        #     row index when rotating K.
        #   q_rope_position_handle: (B,) int32 position used when rotating Q.
        #   output_handle: (B, 16, 128) float16, written with the normalized result.
        #   lse_handle: (B, 16) buffer (no dtype given to match_buffer), written with
        #     st_m + log2(st_d).
        #   rotary_mode: the rotary transform is applied to Q and K only when == 1.
        #   rope_scale, rope_theta: scale and base of the rotary frequency
        #     pos * rope_scale / rope_theta ** ((d * 2 % 128) / 128).
        #   attn_score_scaling_factor: extra multiplier applied to every score.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        B = T.int32(is_size_var=True)
        Q = T.match_buffer(Q_handle, (B, 16, 128), "float16")
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(pages_handle, (max_num_pages, 2, 16, 16, 128), "float16", offset_factor=1)
        page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, B), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(k_rope_pos_offset_handle, (B,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(q_rope_position_handle, (B,), "int32", offset_factor=1)
        output = T.match_buffer(output_handle, (B, 16, 128), "float16")
        lse = T.match_buffer(lse_handle, (B, 16))
        # with T.block("root"):
        # Numerically equal to log2(e) / sqrt(128): scores are pre-scaled so that the
        # softmax below can be evaluated with exp2 instead of exp.
        sm_scale: T.float32 = T.float32(0.12751743082459868)
        # Launch shape: blockIdx.x = batch entry, blockIdx.y = 16 heads,
        # threads = 32 (x) * 1 (y) * 16 (z) per block.
        for bx in T.thread_binding(B, thread="blockIdx.x"):
            for fused_by_bz in T.thread_binding(16, thread="blockIdx.y"):
                for ty in T.thread_binding(1, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        for tz in T.thread_binding(16, thread="threadIdx.z"):
                            with T.block("attn"):
                                # The declared Q read region is 132 elements wide starting at
                                # tx * 4 - 64, which covers the +/-64 partner reads of the rotary branch.
                                T.reads(page_table_indptr[bx:bx + 2], length_info[0:3, bx], q_rope_position[bx], Q[bx, fused_by_bz // 16 + ty + fused_by_bz % 16, tx * 4 - 64:tx * 4 - 64 + 132])
                                T.writes(output[bx, fused_by_bz % 16 + fused_by_bz // 16 + ty, tx * 4:tx * 4 + 4], lse[bx, fused_by_bz % 16 + fused_by_bz // 16 + ty])
                                # Per-thread 4-element slice of the (optionally rotated) query row.
                                Q_local = T.alloc_buffer((4,), "float16", scope="local")
                                kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                                # One 32-row x 128-column tile of K and of V per loop iteration.
                                K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                                V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                                # Staging area for merging the 16 threadIdx.z partial results:
                                # accumulators in O_allreduce, (max, denominator) pairs in md_allreduce.
                                O_allreduce = T.alloc_buffer((16, 1, 128), scope="shared")
                                md_allreduce = T.alloc_buffer((16, 1, 2), scope="shared")
                                S_reduce_local = T.alloc_buffer((1,), scope="local")
                                t0 = T.alloc_buffer((1,), scope="local")
                                S_local = T.alloc_buffer((2,), scope="local")
                                QK_local = T.alloc_buffer((4,), scope="local")
                                V_local = T.alloc_buffer((4,), "float16", scope="local")
                                m_prev = T.alloc_buffer((1,), scope="local")
                                d_prev = T.alloc_buffer((1,), scope="local")
                                other_m = T.alloc_buffer((1,), scope="local")
                                other_d = T.alloc_buffer((1,), scope="local")
                                exp_mprev = T.alloc_buffer((1,), scope="local")
                                exp_otherm = T.alloc_buffer((1,), scope="local")
                                other_o = T.alloc_buffer((4,), scope="local")
                                # Running max (st_m), running denominator (st_d) and running
                                # weighted-V accumulator (O_local) of this thread.
                                st_m = T.alloc_buffer((1,), scope="local")
                                st_d = T.alloc_buffer((1,), scope="local")
                                O_local = T.alloc_buffer((4,), scope="local")
                                # fused_by_bz ranges over 16, so bz is always 0 and by equals fused_by_bz;
                                # ty ranges over 1, so the head index by + bz + ty equals by.
                                by: T.int32 = fused_by_bz % 16
                                bz: T.int32 = fused_by_bz // 16
                                batch_idx: T.int32 = bx
                                cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
                                cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
                                # Number of KV rows to attend to: 0 when the entry owns no pages, otherwise
                                # (pages - 1) * 16 + length_info[0] - length_info[1] + length_info[2].
                                kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, batch_idx] - length_info[1, batch_idx] + length_info[2, batch_idx], 0)
                                # -50000 is both the initial running max and the score given to
                                # out-of-range rows further down.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                # Load this thread's 4 query elements (columns tx * 4 .. tx * 4 + 3). With
                                # rotary_mode == 1 element d becomes
                                # cos(freq) * q[d] + sin(freq) * (-q[d + 64] if d < 64 else q[d - 64]).
                                for vec in T.vectorized(4):
                                    freq = T.float32()
                                    Q_local[vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", Q[bx, by + bz + ty, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, Q[bx, by + bz + ty, tx * 4 + vec + 64] * T.float16(-1.0), Q[bx, by + bz + ty, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", q_rope_position[batch_idx]) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), Q[bx, by + bz + ty, tx * 4 + vec])
                                # Walk the KV rows in tiles of 32 (ceil division of kv_chunk_len by 32).
                                for iterator in range((kv_chunk_len[0] + 31) // 32):
                                    # Each tz slice owns 2 consecutive rows of the tile: tile_start_s is the
                                    # row inside the shared tile, tile_start_g the logical row overall.
                                    tile_start_s: T.int32 = (tz + ty) * 2
                                    tile_start_g: T.int32 = (iterator * 16 + tz + ty) * 2
                                    for j in range(2):
                                        with T.block("KV_load"):
                                            T.reads()
                                            T.writes()
                                            row_g: T.int32 = tile_start_g + j
                                            if row_g < kv_chunk_len[0]:
                                                # Logical row -> position in the paged sequence: rows below
                                                # length_info[2] map to themselves, later rows are shifted
                                                # by length_info[1].
                                                seq_offset: T.int32 = T.if_then_else(row_g < length_info[2, batch_idx], row_g, row_g - length_info[2, batch_idx] + length_info[1, batch_idx])
                                                page_no: T.int32 = page_table_values[cur_page_indptr_begin + seq_offset // 16]
                                                page_offset: T.int32 = seq_offset % 16
                                                for vec in T.vectorized(4):
                                                    freq = T.float32()
                                                    # K gets the same rotary form as Q, at position
                                                    # k_rope_pos_offset[batch_idx] + row_g; V is copied unchanged.
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, tx * 4 + vec]) + T.sin(freq) * T.Cast("float32", T.if_then_else(tx * 4 + vec < 64, pages[page_no, 0, by, page_offset, tx * 4 + vec + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, tx * 4 + vec - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[batch_idx] + row_g) * rope_scale / T.pow(rope_theta, T.Cast("float32", (tx * 4 + vec) * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, tx * 4 + vec])
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = pages[page_no, 1, by, page_offset, tx * 4 + vec]
                                            else:
                                                # Rows past the end of the sequence are zero-filled.
                                                for vec in T.vectorized(4):
                                                    K_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                                    V_smem[tile_start_s + j, tx * 4 + vec] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Remember the max before this tile so the old state can be rescaled.
                                    m_prev[0] = st_m[0]
                                    for j in range(2):
                                        # Per-thread partial dot product over its 4 columns, already scaled
                                        # by attn_score_scaling_factor * sm_scale.
                                        for vec in T.vectorized(4):
                                            QK_local[vec] = T.Cast("float32", Q_local[vec]) * T.Cast("float32", K_smem[tz * 2 + j, tx * 4 + vec]) * attn_score_scaling_factor * sm_scale
                                        S_reduce_local[0] = T.float32(0.0)
                                        for vec in T.unroll(4):
                                            S_reduce_local[0] = S_reduce_local[0] + QK_local[vec]
                                        # Sum the partials across tx with the additive reducer declared
                                        # below; the reduced score is delivered in t0[0].
                                        with T.block("block_cross_thread"):
                                            T.reads(S_reduce_local[0])
                                            T.writes(t0[0])
                                            T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                            T.tvm_thread_allreduce(T.uint32(1), S_reduce_local[0], T.bool(True), t0[0], tx)
                                        # Out-of-range rows keep the -50000 score; in-range rows take the
                                        # reduced dot product.
                                        S_local[j] = T.float32(-50000.0)
                                        if (iterator * 16 + tz) * 2 + j < kv_chunk_len[0]:
                                            S_local[j] = t0[0]
                                        st_m[0] = T.max(st_m[0], S_local[j])
                                    # Rescale the previous denominator and accumulator to the new max, then
                                    # fold in this tile's two rows.
                                    o_scale: T.float32 = T.exp2(m_prev[0] - st_m[0])
                                    st_d[0] = st_d[0] * o_scale
                                    for j in range(2):
                                        S_local[j] = T.exp2(S_local[j] - st_m[0])
                                        st_d[0] = st_d[0] + S_local[j]
                                    for j in T.vectorized(4):
                                        O_local[j] = O_local[j] * o_scale
                                    for j in range(2):
                                        for vec in T.vectorized(4):
                                            V_local[vec] = V_smem[tz * 2 + j, tx * 4 + vec]
                                        for vec in T.vectorized(4):
                                            O_local[vec] = O_local[vec] + T.Cast("float32", V_local[vec]) * S_local[j]
                                # Publish this tz slice's partial state to shared memory.
                                for vec in T.vectorized(4):
                                    O_allreduce[tz, ty, tx * 4 + vec] = O_local[vec]
                                md_allreduce[tz, ty, 0] = st_m[0]
                                md_allreduce[tz, ty, 1] = st_d[0]
                                T.tvm_storage_sync("shared")
                                # Reset the running state and merge all 16 partials. There is no guard on
                                # tz, so every tz slice executes this merge and the stores that follow.
                                st_m[0] = T.float32(-50000.0)
                                st_d[0] = T.float32(1.0)
                                for vec in T.vectorized(4):
                                    O_local[vec] = T.float32(0.0)
                                for j in range(16):
                                    m_prev[0] = st_m[0]
                                    d_prev[0] = st_d[0]
                                    other_m[0] = md_allreduce[j, ty, 0]
                                    other_d[0] = md_allreduce[j, ty, 1]
                                    for vec in T.vectorized(4):
                                        other_o[vec] = O_allreduce[j, ty, tx * 4 + vec]
                                    # Combine (m_prev, d_prev, O_local) with partial j by rescaling both
                                    # sides to the common max.
                                    st_m[0] = T.max(st_m[0], other_m[0])
                                    st_d[0] = d_prev[0] * T.exp2(m_prev[0] - st_m[0]) + other_d[0] * T.exp2(other_m[0] - st_m[0])
                                    exp_mprev[0] = T.exp2(m_prev[0] - st_m[0])
                                    exp_otherm[0] = T.exp2(other_m[0] - st_m[0])
                                    for vec in T.vectorized(4):
                                        O_local[vec] = O_local[vec] * exp_mprev[0] + other_o[vec] * exp_otherm[0]
                                # Normalize by the merged denominator and store as float16.
                                for vec in T.vectorized(4):
                                    O_local[vec] = O_local[vec] / st_d[0]
                                for vec in T.vectorized(4):
                                    output[batch_idx, by + bz + ty, tx * 4 + vec] = T.Cast("float16", O_local[vec])
                                # Base-2 log-sum-exp of the scaled scores for this (batch entry, head).
                                lse[batch_idx, by + bz + ty] = st_m[0] + T.log2(st_d[0])

    @T.prim_func
    def batch_prefill_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 16, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(16, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                            cur_H_qo: T.int32 = by
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_paged_kv_sliding_window(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Batched prefill attention over a paged KV cache, as a scheduled TIR kernel.
        #
        # For every sequence b and every one of the 16 heads, 32-row tiles of Q are
        # multiplied against successive 32-row chunks of K, normalised with a running
        # base-2 softmax (running max + running denominator), and accumulated against V.
        # The normalised result is written to `output`; `m + log2(d)` per row is
        # written to `lse`.
        #
        # Arguments (shapes/dtypes are those of the match_buffer calls below):
        #   _0                        -- int32 scalar; not referenced anywhere in this body.
        #   var_q                     -- (total_len, 16, 128) float16 queries.
        #   var_q_indptr              -- (batch_size + 1,) int32; query rows
        #                                q_indptr[b] .. q_indptr[b + 1] - 1 belong to sequence b.
        #   var_pages                 -- (max_num_pages, 2, 16, 16, 128) float16; axis 1 is read with
        #                                0 for K and 1 for V, axis 2 is indexed by head (`by`),
        #                                axis 3 by the offset inside a 16-entry page.
        #   var_page_indptr           -- (batch_size + 1,) int32; slice of `page_values` owned by sequence b.
        #   var_page_values           -- (nnz_pages,) int32 page numbers indexing axis 0 of `pages`.
        #   var_length_info           -- (3, batch_size) int32.
        #                                NOTE(review): the three rows look like
        #                                [last-page length, sliding-window offset, sink size];
        #                                confirm against the code that fills this buffer.
        #   var_k_rope_pos_offset     -- (batch_size,) int32, added to the KV position for rotary.
        #   var_q_rope_position       -- (total_len,) int32, rotary position of each query row.
        #   var_output                -- (total_len, 16, 128) float16 attention output.
        #   var_lse                   -- (total_len, 16), default match_buffer dtype.
        #   causal                    -- > 0 selects the causal mask, otherwise only the length mask.
        #   rotary_mode               -- == 1 applies rotary embedding to Q and K while loading.
        #   rope_scale, rope_theta    -- parameters of the rotary frequency (see Q_load / K_load).
        #   attn_score_scaling_factor -- extra multiplier applied to every Q.K product.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 16, 16, 128), "float16", offset_factor=1)
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (3, batch_size), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        # with T.block("root"):
        # Launch shape: 16 x 16 blocks (blockIdx.x walks tiles, blockIdx.y is the head)
        # and 4 x 32 = 128 threads per block.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(16, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Per-thread scalars that drive the tile walk over the batch.
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # `iterator` is allocated but never referenced below
                            # (the KV loop uses the loop variable `iterator_1`).
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Shared 32 x 128 staging tiles for Q, K and V, and the 32 x 32 score tile.
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            # Local accumulators; each thread only touches the sub-tile its
                            # index expressions select (2 x 4 of S, 4 x 8 of O).
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Per-row softmax state: running max, previous max, running denominator.
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start at tile `bx` of sequence 0; a sequence has ceil(rows / 32) tiles.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            # Each block keeps taking tiles (stride 16, see the end of the loop)
                            # until it has walked past the last sequence.
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Skip forward over sequences until tile_id indexes a tile
                                # inside the current sequence (or the batch is exhausted).
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    # First query row (relative to the sequence) of this tile.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    # Number of KV positions to attend over:
                                    # (num_pages - 1) * 16 + length_info[0] - length_info[1] + length_info[2],
                                    # or 0 when the sequence owns no pages.
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[0, b_idx] - length_info[1, b_idx] + length_info[2, b_idx], 0)
                                    # NOTE(review): presumably this barrier keeps the re-initialisation below
                                    # from racing with reads of shared state from the previous tile -- confirm.
                                    T.tvm_storage_sync("shared")
                                    # Reset softmax state; only threads with ty * 32 + tx < 32 own a row.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the output accumulator (4 x 8 elements per thread).
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 128 query tile for head `by` into shared memory.
                                    # Rows past the end of the sequence are zero-filled.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            # rotary_mode == 1: q[j] * cos(freq) + rot(q)[j] * sin(freq), where
                                                            # rot(q)[j] is -q[j + 64] for j < 64 and q[j - 64] otherwise, and
                                                            # freq = q_rope_position[cur_L] * rope_scale / rope_theta ** ((2 * j % 128) / 128).
                                                            # Any other rotary_mode copies q unchanged.
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Sweep the KV range in chunks of 32 positions.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        # Load 32 K rows for head `by`; positions >= kv_chunk_len are zero-filled.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                # Map the logical position to an offset inside the sequence's pages:
                                                                # positions below length_info[2, b] map to themselves, later ones
                                                                # are shifted by length_info[1, b] - length_info[2, b].
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                freq = T.float32()
                                                                # Same rotary form as Q_load, but the position is
                                                                # k_rope_pos_offset[b_idx] + cur_L (the logical position, not seq_offset).
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", pages[page_no, 0, by, page_offset, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, pages[page_no, 0, by, page_offset, j + 64] * T.float16(-1.0), pages[page_no, 0, by, page_offset, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), pages[page_no, 0, by, page_offset, j])
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Load the matching 32 V rows (same page lookup, no rotary).
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = T.if_then_else(cur_L < length_info[2, b_idx], cur_L, cur_L - length_info[2, b_idx] + length_info[1, b_idx])
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # S = Q @ K^T for this chunk, accumulated in S_local (2 x 4 per thread).
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                                            T.writes(S_local[i, j])
                                                                            # NOTE(review): 0.12751743082459868 is numerically log2(e) / sqrt(128);
                                                                            # presumably it folds the 1/sqrt(head_dim) scale together with the
                                                                            # conversion to base 2 used by the exp2/log2 calls below -- confirm.
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish the score tile to shared memory so one thread can scan a whole row.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Softmax step 1: new running max over the unmasked scores of the row, and
                                        # the old denominator rescaled to that new max.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        # Mask: with causal > 0 a KV position is visible when
                                                        # kv_pos < kv_chunk_len - q_len + row_ + 1; otherwise when kv_pos < kv_chunk_len.
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Softmax step 2: replace scores by exp2(score - max); masked entries use
                                        # the -50000 sentinel in place of the score.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    # The row guard sits inside the j loop here, unlike the two neighbouring steps.
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Softmax step 3: add the row's probabilities to the denominator and
                                        # write max / denominator / previous max back to shared memory.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # O = O * exp2(m_prev - m) + P @ V for this chunk (4 x 8 per thread).
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    # Rescale the accumulator from the previous max to the new max.
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                                        T.writes(O_local[i, j])
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    # After the KV sweep: divide by the denominator and store rows that
                                    # fall inside the sequence.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                            cur_H_qo: T.int32 = by
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Store m + log2(d) per row; the T.where keeps only the first 32 of the 128 threads.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance by the blockIdx.x extent (16) to this block's next tile.
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_prefill_ragged_kv(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_k_rope_pos_offset: T.handle, var_output: T.handle, var_lse: T.handle, causal: T.int32, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32):
        # Batched attention over ragged (variable-length) query and key/value
        # sequences, expressed as an already-scheduled TIR kernel.
        #
        # Buffers (shapes/dtypes as declared by the T.match_buffer calls below):
        #   q, output         : (qo_len, 16, 128) float16
        #   k, v              : (kv_len, 16, 128) float16
        #   q_indptr          : (batch_size + 1,) int32; q_indptr[b] .. q_indptr[b + 1]
        #                       bounds the q/output rows of batch entry b
        #   kv_indptr         : (batch_size + 1,) int32; same role for the k/v rows
        #   q_rope_position   : (qo_len,) int32; per-row position used to rotate q
        #   k_rope_pos_offset : (batch_size,) int32; added to the in-sequence k index
        #                       to obtain the position used to rotate k
        #   lse               : (qo_len, 16), no explicit dtype given
        # Scalars:
        #   causal            : > 0 selects the causal mask in the softmax update
        #   rotary_mode       : == 1 applies the cos/sin rotation while loading q and k
        #   rope_scale, rope_theta, attn_score_scaling_factor : float32 factors used
        #                       in the frequency and score formulas below
        #
        # Results are written in place to `output` and `lse`.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 16, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 16, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        # Thread bindings: 16 x 16 (blockIdx.x, blockIdx.y) and 4 x 32
        # (threadIdx.y, threadIdx.x), i.e. 128 threads per (bx, by) pair.
        # `by` is used as the middle (size-16) index of q/k/v/output and as the
        # second index of lse; `bx` seeds the tile counter.
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(16, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            # Scheduling state, kept in "local" scope.
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            # NOTE(review): `iterator` is allocated but never referenced in
                            # this function (the kv loop below uses `iterator_1`).
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            # Tiles staged in "shared" scope. K_smem is (128, 32): K_load
                            # writes it as K_smem[j, i], i.e. transposed relative to Q/V.
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((128, 32), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            # Per-row softmax state: running max (m), previous max (m_prev)
                            # and running denominator (d).
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            # Start at tile `bx` of batch entry 0; a batch entry with R query
                            # rows contributes ceil(R / 32) tiles.
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            # Outer loop: each iteration handles one 32-row query tile, then
                            # tile_id advances by 16 (the blockIdx.x extent) at the bottom.
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                # Convert the running tile counter into (batch_idx, tile within
                                # that batch entry) by skipping whole batch entries.
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    # First query row of this tile, relative to the batch entry.
                                    LH_start: T.int32 = tile_id[0] * 32
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    # Reset the per-row softmax state. row = ty * 32 + tx here, so
                                    # only threads with ty == 0 satisfy row < 32.
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    # Zero the 32 x 128 accumulator; each thread covers a
                                    # 4-row x 8-column patch.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Load the 32 x 128 query tile for index `by` into Q_smem.
                                    # Rows at or beyond q_indptr[b_idx + 1] are zero-filled.
                                    for li_1_lj_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_1_lj_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_0, lj_0_0 in T.grid(2, 4):
                                                for lj_1 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, li_0 * 16 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) // 8)
                                                        j = T.axis.spatial(128, lj_0_0 * 32 + (li_1_lj_0_1_fused_0 * 32 + li_1_lj_0_1_fused_1) % 8 * 4 + lj_1)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            # rotary_mode == 1: cos(freq) * q[j] + sin(freq) * partner,
                                                            # where partner is -q[j + 64] for j < 64 and q[j - 64]
                                                            # otherwise, and
                                                            # freq = q_rope_position[cur_L] * rope_scale
                                                            #        / rope_theta ** ((j * 2 % 128) / 128).
                                                            # Any other rotary_mode copies q unchanged.
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    # Walk this batch entry's key/value rows in chunks of 32.
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        # Load the key chunk, stored as K_smem[feature j, kv row i].
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads(kv_chunk_len[0], k_rope_pos_offset[b_idx], k[L_kv_base + L_kv_start + i, by, j - 64:j - 64 + 129])
                                                            T.writes(K_smem[j, i])
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                # Same rotation formula as Q_load, with the position
                                                                # taken as k_rope_pos_offset[b_idx] + cur_L.
                                                                K_smem[j, i] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[L_kv_base + cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[L_kv_base + cur_L, by, j + 64] * T.float16(-1.0), k[L_kv_base + cur_L, by, j - 64]))), where={freq: T.Cast("float32", k_rope_pos_offset[b_idx] + cur_L) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[L_kv_base + cur_L, by, j])
                                                            else:
                                                                # Rows past the end of this sequence are zero-filled.
                                                                K_smem[j, i] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Load the value chunk (no rotation), zero-filling rows past
                                        # the end of the sequence.
                                        for lz_1_ly_0_1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for lz_1_ly_0_1_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for lz_0, ly_0_0 in T.grid(2, 4):
                                                    for ly_1 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, lz_0 * 16 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) // 8)
                                                            j = T.axis.spatial(128, ly_0_0 * 32 + (lz_1_ly_0_1_fused_0 * 32 + lz_1_ly_0_1_fused_1) % 8 * 4 + ly_1)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                V_smem[i, j] = v[L_kv_base + cur_L, by, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        # Scores: S_local[i, j] = sum over k_1 of Q_smem[i, k_1] * K_smem[k_1, j],
                                        # each term scaled. Every thread owns a 2-row x 4-column patch.
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:128, 0:32])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(2):
                                                        for lj_1_0_init in T.unroll(1):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("S_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                                    j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(S_local[i, j])
                                                                    S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0 in range(8):
                                                        for li_1 in T.unroll(2):
                                                            for lj_1_0 in T.unroll(1):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    for lk_1 in range(16):
                                                                        with T.block("S_gemm_update"):
                                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                            k_1 = T.axis.reduce(128, lk_0 * 16 + lk_1)
                                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[k_1, j])
                                                                            T.writes(S_local[i, j])
                                                                            # NOTE(review): 0.12751743082459868 is numerically
                                                                            # log2(e) / sqrt(128), which would fit the exp2/log2
                                                                            # calls used below; confirm against the generator.
                                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[k_1, j]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        # Publish the per-thread score patches to S_smem.
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1 in range(2):
                                                    for lj_1_0 in T.unroll(1):
                                                        for lj_1_1 in T.vectorized(4):
                                                            with T.block("S_store"):
                                                                i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                                j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1_0 * 4 + lj_1_1)
                                                                T.reads(S_local[i, j])
                                                                T.writes(S_smem[i, j])
                                                                S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        # Softmax step 1 (one thread per row, row < 32): take the new
                                        # row maximum over the unmasked scores and rescale the running
                                        # denominator to that maximum.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        # Mask: with causal > 0, query row row_ may use kv
                                                        # positions < kv_len - q_len + row_ + 1; otherwise
                                                        # every position < kv_len is used.
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        # Softmax step 2: replace each score by exp2(score - m_new);
                                        # masked positions use -50000.0 as the score instead.
                                        # Unlike the neighbouring steps, the row < 32 guard sits inside
                                        # the j loop here.
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if T.if_then_else(causal > 0, L_kv_start + j < kv_chunk_len[0] - (q_indptr[b_idx + 1] - q_indptr[b_idx]) + row_ + 1, L_kv_start + j < kv_chunk_len[0]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        # Softmax step 3: add the exponentiated scores to the
                                        # denominator and store max / denominator / previous max for
                                        # the row. (This block shares the name "update" with step 2.)
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        # Output accumulation for this kv chunk.
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init in T.unroll(4):
                                                        for lj_1_0_init in T.unroll(2):
                                                            for lj_1_1_init in T.vectorized(4):
                                                                with T.block("O_gemm_init"):
                                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_0_init * 4 + lj_1_1_init)
                                                                    T.reads()
                                                                    T.writes(O_local[i, j])
                                                                    # Not a zero-init: rescales what was accumulated so
                                                                    # far from the previous row maximum to the new one.
                                                                    O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1 in T.grid(2, 16):
                                                        for li_1 in T.unroll(4):
                                                            for lj_1_0 in T.unroll(2):
                                                                for lj_1_1 in T.vectorized(4):
                                                                    with T.block("O_gemm_update"):
                                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                                        j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                                        k_1 = T.axis.reduce(32, lk_0 * 16 + lk_1)
                                                                        T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                                        T.writes(O_local[i, j])
                                                                        # O += P @ V, with P the exponentiated scores in S_smem.
                                                                        O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    # After all kv chunks: divide by the softmax denominator and
                                    # write the tile out as float16, skipping rows past the end of
                                    # this batch entry.
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1 in range(4):
                                                for lj_1_0 in T.unroll(2):
                                                    for lj_1_1 in T.vectorized(4):
                                                        with T.block("O_store"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1_0 * 4 + lj_1_1)
                                                            T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                            T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                            cur_H_qo: T.int32 = by
                                                            if cur_L < q_indptr[b_idx + 1]:
                                                                output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    # Per-row log-sum-exp in base 2: running max + log2(denominator).
                                    # The T.where predicate restricts the store to the first 32 of
                                    # the 128 bound threads.
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    # Advance to the next tile for this bx (stride = blockIdx.x extent).
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func
    def batch_tree_attn(var_q: T.handle, var_q_indptr: T.handle, var_k: T.handle, var_v: T.handle, var_kv_indptr: T.handle, var_q_rope_position: T.handle, var_mn_indptr: T.handle, var_mask: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, batch_size: T.int32):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        qo_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (qo_len, 16, 128), "float16")
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        kv_len = T.int32(is_size_var=True)
        k = T.match_buffer(var_k, (kv_len, 16, 128), "float16")
        v = T.match_buffer(var_v, (kv_len, 16, 128), "float16")
        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", offset_factor=1)
        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", offset_factor=1)
        tree_size = T.int32(is_size_var=True)
        mask = T.match_buffer(var_mask, (tree_size, 2), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (qo_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (qo_len, 16))
        # with T.block("root"):
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(16, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        L_kv_base: T.int32 = kv_indptr[b_idx]
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("KV_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_base + L_kv_start + i
                                                            if L_kv_start + i < kv_chunk_len[0]:
                                                                freq = T.float32()
                                                                K_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", k[cur_L, by, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, k[cur_L, by, j + 64] * T.float16(-1.0), k[cur_L, by, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), k[cur_L, by, j])
                                                                V_smem[i, j] = v[cur_L, by, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k_1 = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k_1], K_smem[j, k_1])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k_1]) * T.Cast("float32", K_smem[j, k_1]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], mn_indptr[b_idx:b_idx + 2], mask[T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + mn_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + mn_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) or mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 0] and mask[mn_indptr[b_idx] + (row_ + (mn_indptr[b_idx + 1] - mn_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < mask[mn_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (mn_indptr[b_idx + 1] - mn_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k_1 = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k_1], V_smem[k_1, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k_1] * T.Cast("float32", V_smem[k_1, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @T.prim_func(private=True)
    def batch_verify_on_gpu_single_kernel(var_draft_probs: T.handle, var_draft_tokens: T.handle, var_model_probs: T.handle, var_token_tree_first_child: T.handle, var_token_tree_next_sibling: T.handle, var_uniform_samples: T.handle, var_token_tree_parent_ptr: T.handle):
        # Per-batch walk over a token tree, executed as a single CUDA kernel.
        #
        # Launch shape: one blockIdx.x per entry `b` of token_tree_parent_ptr
        # (nbatch blocks), 1024 threadIdx.x threads per block.
        #
        # Inputs (all indexed by node id unless noted):
        #   draft_probs            (num_nodes, vocab_size) float32, read only
        #   draft_tokens           (num_nodes,) int32, token id stored at each node
        #   model_probs            (num_nodes, vocab_size) float32, UPDATED IN PLACE
        #   token_tree_first_child (num_nodes,) int32, -1 terminates the walk
        #   token_tree_next_sibling(num_nodes,) int32
        #   uniform_samples        (num_nodes,) float32
        #   token_tree_parent_ptr  (nbatch,) int32, start node on entry and
        #                          overwritten with the final node on exit
        #
        # Per iteration, for the current (parent, child) pair:
        #   * thread 0 evaluates  p >= u * q  with
        #       p = model_probs[parent, token], q = draft_probs[child, token],
        #       u = uniform_samples[child], token = draft_tokens[child];
        #   * if true, the walk descends: parent <- child, child <- first_child;
        #   * otherwise the block sums max(model_probs[parent] - draft_probs[child], 0)
        #     over the vocabulary; if that sum is below about 1e-7 the walk also
        #     descends, else model_probs[parent] is replaced by that clamped
        #     difference divided by the sum and child <- next_sibling.
        # NOTE(review): this looks like the accept/reject step of tree-based
        # speculative decoding -- confirm against the code that generated this module.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # num_nodes is int32 while vocab_size is int64, hence the int64 casts
        # on `tx` wherever a vocabulary column index is formed below.
        num_nodes, vocab_size = T.int32(is_size_var=True), T.int64(is_size_var=True)
        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size))
        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), "int32")
        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size))
        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), "int32")
        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), "int32")
        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,))
        nbatch = T.int32(is_size_var=True)
        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), "int32")
        # with T.block("root"):
        # One-element scratch buffers. Everything is scope="local" except
        # pred_shared, which is scope="shared" so thread 0 can publish the
        # comparison result to the rest of the block.
        child_ptr = T.alloc_buffer((1,), "int32", scope="local")
        parent_ptr = T.alloc_buffer((1,), "int32", scope="local")
        child_token = T.alloc_buffer((1,), "int32", scope="local")
        done = T.alloc_buffer((1,), "bool", scope="local")
        psum = T.alloc_buffer((1,), scope="local")
        t0 = T.alloc_buffer((1,), scope="local")
        model_prob_local = T.alloc_buffer((1,), scope="local")
        draft_prob_local = T.alloc_buffer((1,), scope="local")
        p_child = T.alloc_buffer((1,), scope="local")
        q_child = T.alloc_buffer((1,), scope="local")
        uniform_sample = T.alloc_buffer((1,), scope="local")
        pred_shared = T.alloc_buffer((1,), "bool", scope="shared")
        pred_local = T.alloc_buffer((1,), "bool", scope="local")
        for _bx in T.thread_binding(nbatch, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(token_tree_parent_ptr[b], token_tree_first_child[T.min(parent_ptr[0], child_ptr[0]):T.min(parent_ptr[0], child_ptr[0]) + (T.max(parent_ptr[0], child_ptr[0]) + 1 - T.min(parent_ptr[0], child_ptr[0]))], parent_ptr[0], done[0], child_ptr[0], draft_tokens[child_ptr[0]], model_probs[parent_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], child_token[0], draft_probs[child_ptr[0], T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)):T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)) + (T.max(T.Cast("int64", child_token[0]), (vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) + T.Cast("int64", tx) - T.int64(1024)) + T.int64(1) - T.min(T.Cast("int64", child_token[0]), T.Cast("int64", tx)))], uniform_samples[child_ptr[0]], p_child[0], uniform_sample[0], q_child[0], pred_shared[0], pred_local[0], model_prob_local[0], draft_prob_local[0], psum[0], t0[0], token_tree_next_sibling[child_ptr[0]])
                    T.writes(parent_ptr[0], child_ptr[0], done[0], child_token[0], p_child[0], q_child[0], uniform_sample[0], pred_shared[0], pred_local[0], psum[0], model_prob_local[0], draft_prob_local[0], t0[0], model_probs[parent_ptr[0], T.Cast("int64", tx):T.Cast("int64", tx) + ((vocab_size + T.int64(1023)) // T.int64(1024) * T.int64(1024) - T.int64(1023))], token_tree_parent_ptr[b])
                    # Every thread keeps its own copy of the walk state and
                    # starts from the node recorded for this batch entry.
                    parent_ptr[0] = token_tree_parent_ptr[b]
                    child_ptr[0] = token_tree_first_child[parent_ptr[0]]
                    done[0] = T.bool(False)
                    while not done[0]:
                        # Barrier at the top of every iteration; presumably this
                        # keeps thread 0 from overwriting pred_shared (and any
                        # thread from overwriting model_probs) before the whole
                        # block has finished the previous iteration.
                        T.tvm_storage_sync("shared")
                        if child_ptr[0] == -1:
                            # No child left to examine: stop the walk.
                            done[0] = T.bool(True)
                            # Mirrors the single barrier issued in the else branch.
                            T.tvm_storage_sync("shared")
                        else:
                            if tx == 0:
                                # Only thread 0 loads the scalars and evaluates
                                # the test p >= u * q, publishing it via shared memory.
                                child_token[0] = draft_tokens[child_ptr[0]]
                                p_child[0] = model_probs[parent_ptr[0], child_token[0]]
                                q_child[0] = draft_probs[child_ptr[0], child_token[0]]
                                uniform_sample[0] = uniform_samples[child_ptr[0]]
                                pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]
                            T.tvm_storage_sync("shared")
                            # After the barrier every thread takes a private copy
                            # of the result so they all follow the same branch.
                            pred_local[0] = pred_shared[0]
                            if pred_local[0]:
                                # Test passed: descend to the child's first child.
                                parent_ptr[0] = child_ptr[0]
                                child_ptr[0] = token_tree_first_child[child_ptr[0]]
                            else:
                                # Test failed: accumulate this thread's share of
                                # sum(max(model_probs[parent] - draft_probs[child], 0)).
                                # Thread tx handles columns tx, tx + 1024, tx + 2048, ...
                                psum[0] = T.float32(0.0)
                                for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                    if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                        model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                        model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                        psum[0] = psum[0] + model_prob_local[0]
                                # Sum the per-thread partials across threadIdx.x;
                                # the total lands in t0[0].
                                with T.block("block_cross_thread"):
                                    T.reads(psum[0])
                                    T.writes(t0[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), psum[0], T.bool(True), t0[0], tx)
                                if t0[0] < T.float32(9.9999999999999995e-08):
                                    # The clamped difference sums to (nearly) zero, so
                                    # dividing by it below is avoided; the walk
                                    # descends exactly as in the passing branch.
                                    parent_ptr[0] = child_ptr[0]
                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]
                                else:
                                    # Overwrite model_probs[parent, :] in place with
                                    # max(model - draft, 0) / t0, then try the next sibling.
                                    for i in range((vocab_size + T.int64(1023)) // T.int64(1024)):
                                        if i * T.int64(1024) + T.Cast("int64", tx) < vocab_size:
                                            model_prob_local[0] = model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            draft_prob_local[0] = draft_probs[child_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)]
                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], T.float32(0.0))
                                            model_probs[parent_ptr[0], i * T.int64(1024) + T.Cast("int64", tx)] = model_prob_local[0] / t0[0]
                                    child_ptr[0] = token_tree_next_sibling[child_ptr[0]]
                    # A single thread records the node the walk ended on.
                    if tx == 0:
                        token_tree_parent_ptr[b] = parent_ptr[0]

    @T.prim_func
    def chunk_lse(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle):
        # Chunk-wise max and log-sum-exp of A along its second axis.
        #
        # Each row of A is viewed as num_chunks chunks of 4096 columns. For every
        # (row, chunk) pair this writes:
        #   chunked_max[row, chunk] = max of the (temperature-scaled) chunk
        #   chunked_sum[row, chunk] = log(sum(exp(x - chunk_max)))  when
        #                             temperature[row] > 1e-5, otherwise the
        #                             number of in-range elements equal to the max.
        #
        # The function carries no "tir.is_scheduled" attribute and has no thread
        # bindings: it is four plain loop nests over intermediate buffers.
        # NOTE(review): presumably num_chunks == ceil(vocab_size / 4096); nothing
        # in this function checks that relationship.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        # with T.block("root"):
        A_pad = T.alloc_buffer((batch_size, num_chunks, T.int64(4096)))
        temp_max = T.alloc_buffer((batch_size, num_chunks))
        temp_sum = T.alloc_buffer((batch_size, num_chunks))
        # Stage 1 ("pad"): reshape A to (batch, chunk, 4096). In-range elements
        # are divided by temperature when temperature > 1e-5 and copied unchanged
        # otherwise; columns at or beyond vocab_size are filled with the lowest
        # finite float32 so they can never win the max.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("pad"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2])
                T.writes(A_pad[v0, v1, v2])
                A_pad[v0, v1, v2] = T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), A[v0, v1 * T.int64(4096) + v2] / temperature[v0], A[v0, v1 * T.int64(4096) + v2]), T.float32(-340282346638528859811704183484516925440.0))
        # Stage 2 ("max"): reduce each 4096-wide chunk to its maximum.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("max"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(A_pad[v0, v1, v2])
                T.writes(temp_max[v0, v1])
                with T.init():
                    temp_max[v0, v1] = T.float32(-340282346638528859811704183484516925440.0)
                temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])
        # Stage 3 ("sum_exp"): padded columns contribute 0. In-range columns
        # contribute exp(x - chunk_max) when temperature > 1e-5; otherwise they
        # contribute 1.0 if x equals the chunk max and 0.0 if not.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(4096)):
            with T.block("sum_exp"):
                v0, v1, v2 = T.axis.remap("SSR", [l0, l1, l2])
                T.reads(temperature[v0], A_pad[v0, v1, v2], temp_max[v0, v1])
                T.writes(temp_sum[v0, v1])
                with T.init():
                    temp_sum[v0, v1] = T.float32(0.0)
                temp_sum[v0, v1] = temp_sum[v0, v1] + T.if_then_else(v1 * T.int64(4096) + v2 < vocab_size, T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]), T.Cast("float32", A_pad[v0, v1, v2] == temp_max[v0, v1])), T.float32(0.0))
        # Stage 4 ("log"): emit the outputs. The sum is passed through log only
        # when temperature > 1e-5; otherwise the raw count is stored. The third
        # loop has extent 1 and v2 is not used in the body.
        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):
            with T.block("log"):
                v0, v1, v2 = T.axis.remap("SSS", [l0, l1, l2])
                T.reads(temperature[v0], temp_sum[v0, v1], temp_max[v0, v1])
                T.writes(chunked_sum[v0, v1], chunked_max[v0, v1])
                chunked_sum[v0, v1] = T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.log(temp_sum[v0, v1]), temp_sum[v0, v1])
                chunked_max[v0, v1] = temp_max[v0, v1]

    @T.prim_func
    def compact_kv_copy(var_pages: T.handle, var_copy_length_indptr: T.handle, var_copy_src_dst_pos: T.handle, batch_size: T.int32):
        # In-place element copies inside the paged buffer `pages`.
        #
        # pages has shape (num_pages, 2, 16, 16, 128) float16. A flat position
        # `pos` addresses page `pos // 16` and slot `pos % 16` on axis 3.
        # For batch entry b, the copies to perform are columns
        # copy_length_indptr[b] .. copy_length_indptr[b + 1] - 1 of
        # copy_src_dst_pos, whose row 0 holds source positions and row 1 holds
        # destination positions.
        # NOTE(review): axis 1 (size 2) is presumably the K/V split and axis 2
        # (size 16) the head index -- confirm against the cache layout.
        #
        # Launch shape: batch_size * 2 blocks of 1024 threads, i.e. one thread
        # per (b, h, d) triple with h in [0, 16) and d in [0, 128).
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        num_pages = T.int32()
        pages = T.match_buffer(var_pages, (num_pages, 2, 16, 16, 128), "float16", offset_factor=1)
        copy_length_indptr = T.match_buffer(var_copy_length_indptr, (batch_size + 1,), "int32", offset_factor=1)
        total_copy_length = T.int32()
        copy_src_dst_pos = T.match_buffer(var_copy_src_dst_pos, (2, total_copy_length), "int32", offset_factor=1)
        with T.block("root"):
            T.reads()
            T.writes()
            for bhd_o in T.thread_binding(batch_size * 2, thread="blockIdx.x"):
                for bhd_i in T.thread_binding(1024, thread="threadIdx.x"):
                    # Split the flat thread id into (b, h, d): 2048 = 16 * 128
                    # threads per batch entry, 128 threads per h.
                    b: T.int32 = (bhd_o * 1024 + bhd_i) // 2048
                    h: T.int32 = (bhd_o * 1024 + bhd_i) // 128 % 16
                    d: T.int32 = (bhd_o * 1024 + bhd_i) % 128
                    # Bounds guard; with this launch shape the grid holds exactly
                    # batch_size * 2048 threads.
                    if bhd_o * 1024 + bhd_i < batch_size * 16 * 128:
                        # The copies of one batch entry are applied one after
                        # another, in list order, by each thread.
                        for i in range(copy_length_indptr[b + 1] - copy_length_indptr[b]):
                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
                            # Copy both entries of axis 1 for this (h, d).
                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[src_pos // 16, 0, h, src_pos % 16, d]
                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[src_pos // 16, 1, h, src_pos % 16, d]

    @T.prim_func
    def copy_single_page(var_pages: T.handle, src_page_id: T.int64, tgt_page_id: T.int64, copy_length: T.int64):
        # Copy the first `copy_length` slots (axis 3) of page `src_page_id` into
        # page `tgt_page_id`, for both entries of axis 1, all 16 entries of
        # axis 2 and all 128 entries of axis 4.
        #
        # pages has shape (num_pages, 2, 16, page_size, 128) float16 and is
        # modified in place.
        # NOTE(review): nothing here checks copy_length <= page_size or that the
        # page ids are in range -- presumably guaranteed by the caller.
        #
        # Launch shape: copy_length * 2 blocks of 1024 threads, i.e. one thread
        # per (vh, vp, vd) with copy_length * 16 * 128 threads in total.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        num_pages, page_size = T.int32(), T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 16, page_size, 128), "float16", offset_factor=1)
        # with T.block("root"):
        for b in T.thread_binding(copy_length * T.int64(2), thread="blockIdx.x"):
            for t in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("copy"):
                    # Decompose the flat thread id (b * 1024 + t) as
                    # vh * (copy_length * 128) + vp * 128 + vd.
                    vh = T.axis.spatial(16, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) // (copy_length * T.int64(128))))
                    vp = T.axis.spatial(copy_length, (b * T.int64(1024) + T.Cast("int64", t)) % (copy_length * T.int64(128)) // T.int64(128))
                    vd = T.axis.spatial(128, T.Cast("int32", (b * T.int64(1024) + T.Cast("int64", t)) % T.int64(128)))
                    # Bounds predicate on the flat thread id.
                    T.where(b * T.int64(1024) + T.Cast("int64", t) < copy_length * T.int64(16) * T.int64(128))
                    T.reads(pages[src_page_id, 0:2, vh, vp, vd])
                    T.writes(pages[tgt_page_id, 0:2, vh, vp, vd])
                    # Copy both entries of axis 1 for this element.
                    pages[tgt_page_id, 0, vh, vp, vd] = pages[src_page_id, 0, vh, vp, vd]
                    pages[tgt_page_id, 1, vh, vp, vd] = pages[src_page_id, 1, vh, vp, vd]

    @T.prim_func(private=True)
    def dequantize_group_gemm(var_x: T.handle, w: T.Buffer((60, 2816, 256), "uint32"), scale: T.Buffer((60, 2816, 64), "float16"), indptr: T.Buffer((61,), "int32"), var_o: T.handle):
        # Grouped GEMM over 60 weight groups with on-the-fly 4-bit dequantization.
        #
        #   X      : (B, 2048) float16 input rows
        #   w      : (60, 2816, 256) uint32; each word packs eight 4-bit values
        #            along the reduction axis (256 * 8 = 2048)
        #   scale  : (60, 2816, 64) float16; one scale per 32 reduction elements
        #   indptr : (61,) int32; indptr[e + 1] - indptr[e] is the row count of group e
        #   O      : (B, 2816) float16 output, O[m, n] = sum_k X[m, k] * dequant(w[e, n, k])
        #
        # Work is cut into 8-row x 128-column output tiles. Group e owns
        # ceil(rows_e / 8) * 22 consecutive tile ids (22 * 128 = 2816). The 1024
        # thread blocks each process tile ids bx, bx + 1024, bx + 2048, ...
        # NOTE(review): row offsets are accumulated from the per-group deltas starting
        # at 0, so they equal indptr[e] only if indptr[0] == 0 -- confirm with callers.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B = T.int32(is_size_var=True)
        X = T.match_buffer(var_x, (B, 2048), "float16")
        O = T.match_buffer(var_o, (B, 2816), "float16")
        # with T.block("root"):
        for _bx in T.thread_binding(1024, thread="blockIdx.x"):
            with T.block("CTA"):
                bx = T.axis.spatial(1024, _bx)
                T.reads(X[0:B, 0:2048], w[0:60, 0:2816, 0:256], scale[0:60, 0:2816, 0:64], indptr[0:61])
                T.writes(O[0:B, 0:2816])
                # Scheduler state, kept across iterations of the outer while loop:
                #   sum[0], sum[1] : first tile id of the current group / of the next group
                #   row[0], row[1] : first row of the current group / of the next group
                #   cur_e          : current group index
                #   tile_id        : tile id this block is working on
                sum = T.alloc_buffer((2,), "int32", scope="local")
                row = T.alloc_buffer((2,), "int32", scope="local")
                cur_e = T.alloc_buffer((1,), "int32", scope="local")
                tile_id = T.alloc_buffer((1,), "int32", scope="local")
                # Initialise the state for group 0.
                sum[0] = 0
                sum[1] = (indptr[1] - indptr[0] + 8 - 1) // 8 * 22
                row[0] = 0
                row[1] = indptr[1] - indptr[0]
                cur_e[0] = 0
                tile_id[0] = bx
                # The loop condition is wrapped in T.tvm_thread_invariant.
                while T.tvm_thread_invariant(cur_e[0] < 60):
                    # Advance to the group whose tile-id range contains tile_id
                    # (or run off the end, leaving cur_e == 60).
                    while sum[1] <= tile_id[0] and cur_e[0] < 60:
                        cur_e[0] = cur_e[0] + 1
                        if cur_e[0] < 60:
                            e: T.int32 = cur_e[0]
                            delta: T.int32 = indptr[e + 1] - indptr[e]
                            sum[0] = sum[1]
                            sum[1] = sum[1] + (delta + 8 - 1) // 8 * 22
                            row[0] = row[1]
                            row[1] = row[1] + delta
                    T.tvm_storage_sync("shared")
                    if T.tvm_thread_invariant(cur_e[0] < 60):
                        e: T.int32 = cur_e[0]
                        # Tile index inside the current group, then its row/column origin.
                        num_tiles: T.int32 = tile_id[0] - sum[0]
                        m_offset: T.int32 = num_tiles // 22 * 8 + row[0]
                        n_offset: T.int32 = num_tiles % 22 * 128
                        with T.block("gemm"):
                            T.reads(row[1], X[m_offset:m_offset + 8, 0:2048], w[e, n_offset:n_offset + 128, 0:256], scale[e, n_offset:n_offset + 128, 0:64])
                            T.writes(O[m_offset:m_offset + 8, n_offset:n_offset + 128])
                            X_tile = T.alloc_buffer((8, 2048), "float16", scope="shared")
                            W_tile = T.alloc_buffer((128, 2048), "float16", scope="shared")
                            # Accumulator tile; it is only ever assigned float32 values below.
                            O_tile = T.alloc_buffer((8, 128), scope="local")
                            # 32 x 8 threads; thread (a1_0, a0_0) owns output row a0_0 and
                            # the four columns a1_0 * 4 .. a1_0 * 4 + 3 of the tile.
                            for a1_0 in T.thread_binding(32, thread="threadIdx.y"):
                                for a0_0 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                                    for a1_1_init, a0_1_init in T.grid(4, 1):
                                        for a0_2_init in T.vectorized(1):
                                            with T.block("compute_init"):
                                                i = T.axis.spatial(8, a0_0 + a0_1_init + a0_2_init)
                                                j = T.axis.spatial(128, a1_0 * 4 + a1_1_init)
                                                T.reads()
                                                T.writes(O_tile[i, j])
                                                O_tile[i, j] = T.float32(0.0)
                                    # Reduction axis is processed in 64 chunks of 32 (64 * 32 = 2048).
                                    for a2_0 in range(64):
                                        # Stage an 8 x 32 slice of X: one element per thread.
                                        for ax0_ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                                            for ax0_ax1_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                for ax0_ax1_fused_2 in range(1):
                                                    for ax0_ax1_fused_3 in T.vectorized(1):
                                                        with T.block("X_shared"):
                                                            i = T.axis.spatial(8, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1 + ax0_ax1_fused_2 + ax0_ax1_fused_3) // 32)
                                                            j = T.axis.spatial(2048, a2_0 * 32 + (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1 + ax0_ax1_fused_2 + ax0_ax1_fused_3) % 32)
                                                            T.reads(row[1], X[m_offset + i, j])
                                                            T.writes(X_tile[i, j])
                                                            # Rows at or beyond row[1] (the end of this group) are zero-filled.
                                                            X_tile[i, j] = T.if_then_else(m_offset + i < row[1], X[m_offset + i, j], T.float16(0.0))
                                        # Stage a dequantized 128 x 32 slice of w: 16 elements per thread.
                                        for ax0_ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                                            for ax0_ax1_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                for ax0_ax1_fused_2 in range(16):
                                                    for ax0_ax1_fused_3 in T.vectorized(1):
                                                        with T.block("W_shared"):
                                                            i = T.axis.spatial(128, (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2 + ax0_ax1_fused_3) // 32)
                                                            j = T.axis.spatial(2048, a2_0 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2 + ax0_ax1_fused_3) % 32)
                                                            T.reads(w[e, n_offset + i, j // 8], scale[e, n_offset + i, j // 32])
                                                            T.writes(W_tile[i, j])
                                                            # Dequantize: take nibble (j % 8) of word j // 8, subtract 7,
                                                            # multiply by the scale of 32-element group j // 32.
                                                            W_tile[i, j] = T.if_then_else(n_offset + i < 2816, (T.Cast("float16", T.bitwise_and(T.shift_right(w[e, n_offset + i, j // 8], T.Cast("uint32", j % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * scale[e, n_offset + i, j // 32], T.float16(0.0))
                                        for a2_1, a1_1, a0_1 in T.grid(32, 4, 1):
                                            for a0_2 in T.vectorized(1):
                                                with T.block("compute_update"):
                                                    i = T.axis.spatial(8, a0_0 + a0_1 + a0_2)
                                                    j = T.axis.spatial(128, a1_0 * 4 + a1_1)
                                                    k = T.axis.reduce(2048, a2_0 * 32 + a2_1)
                                                    T.reads(O_tile[i, j], X_tile[i, k], W_tile[j, k])
                                                    T.writes(O_tile[i, j])
                                                    # The product is formed in float16, then cast and accumulated in float32.
                                                    O_tile[i, j] = O_tile[i, j] + T.Cast("float32", X_tile[i, k] * W_tile[j, k])
                                    for ax0, ax1_0 in T.grid(1, 4):
                                        for ax1_1 in T.vectorized(1):
                                            with T.block("store"):
                                                i = T.axis.spatial(8, a0_0 + ax0)
                                                j = T.axis.spatial(128, a1_0 * 4 + ax1_0 + ax1_1)
                                                T.reads(row[1], O_tile[i, j])
                                                T.writes(O[m_offset + i, n_offset + j])
                                                # Skip padding rows past the group end and columns past 2816.
                                                if m_offset + i < row[1] and n_offset + j < 2816:
                                                    O[m_offset + i, n_offset + j] = T.Cast("float16", O_tile[i, j])
                    # Move on to this block's next tile (stride = number of blocks).
                    tile_id[0] = tile_id[0] + 1024

    @T.prim_func(private=True)
    def dequantize_group_gemm1(var_x: T.handle, w: T.Buffer((60, 2048, 176), "uint32"), scale: T.Buffer((60, 2048, 44), "float16"), indptr: T.Buffer((61,), "int32"), var_o: T.handle):
        # Grouped GEMM over 60 weight groups with on-the-fly 4-bit dequantization.
        #
        #   X      : (B, 1408) float16 input rows
        #   w      : (60, 2048, 176) uint32; each word packs eight 4-bit values
        #            along the reduction axis (176 * 8 = 1408)
        #   scale  : (60, 2048, 44) float16; one scale per 32 reduction elements
        #   indptr : (61,) int32; indptr[e + 1] - indptr[e] is the row count of group e
        #   O      : (B, 2048) float16 output, O[m, n] = sum_k X[m, k] * dequant(w[e, n, k])
        #
        # Work is cut into 8-row x 128-column output tiles. Group e owns
        # ceil(rows_e / 8) * 16 consecutive tile ids (16 * 128 = 2048). The 1024
        # thread blocks each process tile ids bx, bx + 1024, bx + 2048, ...
        # NOTE(review): row offsets are accumulated from the per-group deltas starting
        # at 0, so they equal indptr[e] only if indptr[0] == 0 -- confirm with callers.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B = T.int32(is_size_var=True)
        X = T.match_buffer(var_x, (B, 1408), "float16")
        O = T.match_buffer(var_o, (B, 2048), "float16")
        # with T.block("root"):
        for _bx in T.thread_binding(1024, thread="blockIdx.x"):
            with T.block("CTA"):
                bx = T.axis.spatial(1024, _bx)
                T.reads(X[0:B, 0:1408], w[0:60, 0:2048, 0:176], scale[0:60, 0:2048, 0:44], indptr[0:61])
                T.writes(O[0:B, 0:2048])
                # Scheduler state, kept across iterations of the outer while loop:
                #   sum[0], sum[1] : first tile id of the current group / of the next group
                #   row[0], row[1] : first row of the current group / of the next group
                #   cur_e          : current group index
                #   tile_id        : tile id this block is working on
                sum = T.alloc_buffer((2,), "int32", scope="local")
                row = T.alloc_buffer((2,), "int32", scope="local")
                cur_e = T.alloc_buffer((1,), "int32", scope="local")
                tile_id = T.alloc_buffer((1,), "int32", scope="local")
                # Initialise the state for group 0.
                sum[0] = 0
                sum[1] = (indptr[1] - indptr[0] + 8 - 1) // 8 * 16
                row[0] = 0
                row[1] = indptr[1] - indptr[0]
                cur_e[0] = 0
                tile_id[0] = bx
                # The loop condition is wrapped in T.tvm_thread_invariant.
                while T.tvm_thread_invariant(cur_e[0] < 60):
                    # Advance to the group whose tile-id range contains tile_id
                    # (or run off the end, leaving cur_e == 60).
                    while sum[1] <= tile_id[0] and cur_e[0] < 60:
                        cur_e[0] = cur_e[0] + 1
                        if cur_e[0] < 60:
                            e: T.int32 = cur_e[0]
                            delta: T.int32 = indptr[e + 1] - indptr[e]
                            sum[0] = sum[1]
                            sum[1] = sum[1] + (delta + 8 - 1) // 8 * 16
                            row[0] = row[1]
                            row[1] = row[1] + delta
                    T.tvm_storage_sync("shared")
                    if T.tvm_thread_invariant(cur_e[0] < 60):
                        e: T.int32 = cur_e[0]
                        # Tile index inside the current group, then its row/column origin.
                        num_tiles: T.int32 = tile_id[0] - sum[0]
                        m_offset: T.int32 = num_tiles // 16 * 8 + row[0]
                        n_offset: T.int32 = num_tiles % 16 * 128
                        with T.block("gemm"):
                            T.reads(row[1], X[m_offset:m_offset + 8, 0:1408], w[e, n_offset:n_offset + 128, 0:176], scale[e, n_offset:n_offset + 128, 0:44])
                            T.writes(O[m_offset:m_offset + 8, n_offset:n_offset + 128])
                            X_tile = T.alloc_buffer((8, 1408), "float16", scope="shared")
                            W_tile = T.alloc_buffer((128, 1408), "float16", scope="shared")
                            # Accumulator tile; it is only ever assigned float32 values below.
                            O_tile = T.alloc_buffer((8, 128), scope="local")
                            # 32 x 8 threads; thread (a1_0, a0_0) owns output row a0_0 and
                            # the four columns a1_0 * 4 .. a1_0 * 4 + 3 of the tile.
                            for a1_0 in T.thread_binding(32, thread="threadIdx.y"):
                                for a0_0 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                                    for a1_1_init, a0_1_init in T.grid(4, 1):
                                        for a0_2_init in T.vectorized(1):
                                            with T.block("compute_init"):
                                                i = T.axis.spatial(8, a0_0 + a0_1_init + a0_2_init)
                                                j = T.axis.spatial(128, a1_0 * 4 + a1_1_init)
                                                T.reads()
                                                T.writes(O_tile[i, j])
                                                O_tile[i, j] = T.float32(0.0)
                                    # Reduction axis is processed in 44 chunks of 32 (44 * 32 = 1408).
                                    for a2_0 in range(44):
                                        # Stage an 8 x 32 slice of X: one element per thread.
                                        for ax0_ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                                            for ax0_ax1_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                for ax0_ax1_fused_2 in range(1):
                                                    for ax0_ax1_fused_3 in T.vectorized(1):
                                                        with T.block("X_shared"):
                                                            i = T.axis.spatial(8, (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1 + ax0_ax1_fused_2 + ax0_ax1_fused_3) // 32)
                                                            j = T.axis.spatial(1408, a2_0 * 32 + (ax0_ax1_fused_0 * 8 + ax0_ax1_fused_1 + ax0_ax1_fused_2 + ax0_ax1_fused_3) % 32)
                                                            T.reads(row[1], X[m_offset + i, j])
                                                            T.writes(X_tile[i, j])
                                                            # Rows at or beyond row[1] (the end of this group) are zero-filled.
                                                            X_tile[i, j] = T.if_then_else(m_offset + i < row[1], X[m_offset + i, j], T.float16(0.0))
                                        # Stage a dequantized 128 x 32 slice of w: 16 elements per thread.
                                        for ax0_ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"):
                                            for ax0_ax1_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
                                                for ax0_ax1_fused_2 in range(16):
                                                    for ax0_ax1_fused_3 in T.vectorized(1):
                                                        with T.block("W_shared"):
                                                            i = T.axis.spatial(128, (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2 + ax0_ax1_fused_3) // 32)
                                                            j = T.axis.spatial(1408, a2_0 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 16 + ax0_ax1_fused_2 + ax0_ax1_fused_3) % 32)
                                                            T.reads(w[e, n_offset + i, j // 8], scale[e, n_offset + i, j // 32])
                                                            T.writes(W_tile[i, j])
                                                            # Dequantize: take nibble (j % 8) of word j // 8, subtract 7,
                                                            # multiply by the scale of 32-element group j // 32.
                                                            W_tile[i, j] = T.if_then_else(n_offset + i < 2048, (T.Cast("float16", T.bitwise_and(T.shift_right(w[e, n_offset + i, j // 8], T.Cast("uint32", j % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * scale[e, n_offset + i, j // 32], T.float16(0.0))
                                        for a2_1, a1_1, a0_1 in T.grid(32, 4, 1):
                                            for a0_2 in T.vectorized(1):
                                                with T.block("compute_update"):
                                                    i = T.axis.spatial(8, a0_0 + a0_1 + a0_2)
                                                    j = T.axis.spatial(128, a1_0 * 4 + a1_1)
                                                    k = T.axis.reduce(1408, a2_0 * 32 + a2_1)
                                                    T.reads(O_tile[i, j], X_tile[i, k], W_tile[j, k])
                                                    T.writes(O_tile[i, j])
                                                    # The product is formed in float16, then cast and accumulated in float32.
                                                    O_tile[i, j] = O_tile[i, j] + T.Cast("float32", X_tile[i, k] * W_tile[j, k])
                                    for ax0, ax1_0 in T.grid(1, 4):
                                        for ax1_1 in T.vectorized(1):
                                            with T.block("store"):
                                                i = T.axis.spatial(8, a0_0 + ax0)
                                                j = T.axis.spatial(128, a1_0 * 4 + ax1_0 + ax1_1)
                                                T.reads(row[1], O_tile[i, j])
                                                T.writes(O[m_offset + i, n_offset + j])
                                                # Skip padding rows past the group end and columns past 2048.
                                                if m_offset + i < row[1] and n_offset + j < 2048:
                                                    O[m_offset + i, n_offset + j] = T.Cast("float16", O_tile[i, j])
                    # Move on to this block's next tile (stride = number of blocks).
                    tile_id[0] = tile_id[0] + 1024

    @T.prim_func
    def full(var_result: T.handle, value: T.int32):
        # Fill the (batch_size, 1) int32 buffer `result` with the scalar `value`.
        # The loop below is a plain serial loop: this function carries no thread
        # bindings and no "tir.is_scheduled" attribute.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch_size = T.int32(is_size_var=True)
        result = T.match_buffer(var_result, (batch_size, 1), "int32")
        # with T.block("root"):
        for i in range(batch_size):
            with T.block("block"):
                vi = T.axis.spatial(batch_size, i)
                T.reads()
                T.writes(result[vi, 0])
                result[vi, 0] = value

    @T.prim_func(private=True)
    def fuse_add_norm_decode(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused elementwise add + RMS-style normalization over the last axis (2048).
        #
        #   A, B : (batch_size, 1, 2048) float16 inputs
        #   C    : (2048,) float16 per-element weight
        #   add  : (batch_size, 1, 2048) float16 output, add = A + B
        #   O    : (batch_size, 1, 2048) float16 output,
        #          O = rsqrt(sum(add^2) * (1 / 2048) + 1e-6) * add * C   (computed in float32)
        #
        # One thread block per batch row; each of the 1024 threads handles the two
        # elements h = tx and h = 1024 + tx, kept in add_local[0] and add_local[1].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size = T.int32()
        A = T.match_buffer(pA, (batch_size, 1, 2048), "float16")
        B = T.match_buffer(pB, (batch_size, 1, 2048), "float16")
        O = T.match_buffer(pO, (batch_size, 1, 2048), "float16")
        add = T.match_buffer(pAdd, (batch_size, 1, 2048), "float16")
        # with T.block("root"):
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        # sum_shared / sum_local are declared with a batch_size extent, but every
        # access below uses only the block's own bx index.
        sum_shared = T.alloc_buffer((batch_size, 1), scope="shared")
        sum_local = T.alloc_buffer((1024, batch_size, 1), scope="local")
        for v_bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                for i in range(2):
                    # Step 1: add in float16 and keep the result in a per-thread slot.
                    with T.block("T_add"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(A[bx, 0, h], B[bx, 0, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[bx, 0, h] + B[bx, 0, h]
                    # Step 2: write the sum out as the `add` result.
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        v_ax1 = T.axis.spatial(1, 0)
                        h = T.axis.spatial(2048, i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[bx, v_ax1, h])
                        add[bx, v_ax1, h] = add_local[h // 1024]
                # Step 3: per-thread partial sum of squares, accumulated in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, bx, 0])
                    sum_local[tx, bx, 0] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, bx, 0], add_local[i])
                        T.writes(sum_local[tx, bx, 0])
                        sum_local[tx, bx, 0] = sum_local[tx, bx, 0] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Step 4: reduce the 1024 per-thread partials over threadIdx.x into sum_shared.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, bx, 0])
                        T.writes(sum_shared[bx, 0])
                        with T.init():
                            sum_shared[bx, 0] = T.float32(0.0)
                        sum_shared[bx, 0] = sum_shared[bx, 0] + sum_local[tx, bx, 0]
            # Step 5: normalize. add_local was written under the v_tx binding above and
            # is read here under v_tx_2; both bind threadIdx.x with extent 1024.
            for i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(batch_size, v_bx)
                        h = T.axis.spatial(2048, i * 1024 + v_tx_2)
                        T.reads(sum_shared[bx, 0], add_local[h // 1024], C[h])
                        T.writes(O[bx, 0, h])
                        # 0.00048828125 == 1 / 2048 (mean over the row); the second constant is 1e-6.
                        O[bx, 0, h] = T.Cast("float16", T.rsqrt(sum_shared[bx, 0] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[h // 1024]) * T.Cast("float32", C[h]))

    @T.prim_func(private=True)
    def fuse_add_norm_prefill(pA: T.handle, pB: T.handle, C: T.Buffer((2048,), "float16"), pO: T.handle, pAdd: T.handle):
        # Fused elementwise add + RMS-style normalization over the last axis (2048).
        #
        #   A, B : (1, seq_len, 2048) float16 inputs
        #   C    : (2048,) float16 per-element weight
        #   add  : (1, seq_len, 2048) float16 output, add = A + B
        #   O    : (1, seq_len, 2048) float16 output,
        #          O = rsqrt(sum(add^2) * (1 / 2048) + 1e-6) * add * C   (computed in float32)
        #
        # One thread block per sequence position; each of the 1024 threads handles the
        # two elements h = tx and h = 1024 + tx, kept in add_local[0] and add_local[1].
        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        A = T.match_buffer(pA, (1, seq_len, 2048), "float16")
        B = T.match_buffer(pB, (1, seq_len, 2048), "float16")
        O = T.match_buffer(pO, (1, seq_len, 2048), "float16")
        add = T.match_buffer(pAdd, (1, seq_len, 2048), "float16")
        # with T.block("root"):
        add_local = T.alloc_buffer((2,), "float16", scope="local")
        # sum_shared / sum_local are declared with a seq_len extent, but every
        # access below uses only the block's own bx index.
        sum_shared = T.alloc_buffer((1, seq_len), scope="shared")
        sum_local = T.alloc_buffer((1024, 1, seq_len), scope="local")
        for v_bx in T.thread_binding(seq_len, thread="blockIdx.x"):
            for v_tx in T.thread_binding(1024, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                for v_i in range(2):
                    # Step 1: add in float16 and keep the result in a per-thread slot.
                    with T.block("T_add"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(A[0, bx, h], B[0, bx, h])
                        T.writes(add_local[h // 1024])
                        add_local[h // 1024] = A[0, bx, h] + B[0, bx, h]
                    # Step 2: write the sum out as the `add` result.
                    with T.block("T_write_back"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        h = T.axis.spatial(2048, v_i * 1024 + v_tx)
                        T.reads(add_local[h // 1024])
                        T.writes(add[0, bx, h])
                        add[0, bx, h] = add_local[h // 1024]
                # Step 3: per-thread partial sum of squares, accumulated in float32.
                with T.block("T_multiply_red_rf_init"):
                    tx, bx = T.axis.remap("SS", [v_tx, v_bx])
                    T.reads()
                    T.writes(sum_local[tx, 0, bx])
                    sum_local[tx, 0, bx] = T.float32(0.0)
                for v_i, _j in T.grid(2, 1):
                    with T.block("T_multiply_red_rf_update"):
                        tx, bx, i = T.axis.remap("SSR", [v_tx, v_bx, v_i])
                        T.reads(sum_local[tx, 0, bx], add_local[i])
                        T.writes(sum_local[tx, 0, bx])
                        sum_local[tx, 0, bx] = sum_local[tx, 0, bx] + T.Cast("float32", add_local[i]) * T.Cast("float32", add_local[i])
            # Step 4: reduce the 1024 per-thread partials over threadIdx.x into sum_shared.
            for _j in range(1):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_multiply_red"):
                        tx, bx = T.axis.remap("RS", [v_tx_2, v_bx])
                        T.reads(sum_local[tx, 0, bx])
                        T.writes(sum_shared[0, bx])
                        with T.init():
                            sum_shared[0, bx] = T.float32(0.0)
                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]
            # Step 5: normalize. add_local was written under the v_tx binding above and
            # is read here under v_tx_2; both bind threadIdx.x with extent 1024.
            for v_i in range(2):
                for v_tx_2 in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("T_cast_2"):
                        bx = T.axis.spatial(seq_len, v_bx)
                        v1 = T.axis.spatial(2048, v_i * 1024 + v_tx_2)
                        T.reads(sum_shared[0, bx], add_local[v1 // 1024], C[v1])
                        T.writes(O[0, bx, v1])
                        # 0.00048828125 == 1 / 2048 (mean over the row); the second constant is 1e-6.
                        O[0, bx, v1] = T.Cast("float16", T.rsqrt(sum_shared[0, bx] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07)) * T.Cast("float32", add_local[v1 // 1024]) * T.Cast("float32", C[v1]))

    @T.prim_func(private=True)
    def fused_NT_matmul12_cast4(reshape220: T.Buffer((T.int64(1), T.int64(2048)), "float16"), model_layers_0_mlp_gate_weight2: T.Buffer((T.int64(60), T.int64(2048)), "float16"), compute_intermediate: T.Buffer((T.int64(1), T.int64(60)), "float32")):
        # Matrix product with a transposed right operand, followed by a cast:
        #   compute_intermediate[0, n] = float32( sum_k reshape220[0, k] * weight[n, k] )
        # Inputs are float16 with shapes (1, 2048) and (60, 2048); the output is
        # (1, 60) float32. No thread bindings appear here (plain loop nests).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # The reduction is accumulated in float16; widening happens only afterwards.
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(60)), "float16")
        for i0, i1, k in T.grid(T.int64(1), T.int64(60), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(reshape220[v_i0, v_k], model_layers_0_mlp_gate_weight2[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + reshape220[v_i0, v_k] * model_layers_0_mlp_gate_weight2[v_i1, v_k]
        # Elementwise cast of the float16 result to float32.
        for i0, i1 in T.grid(T.int64(1), T.int64(60)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(NT_matmul_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1])

    @T.prim_func(private=True)
    def fused_NT_matmul15_tir_sigmoid1(reshape220: T.Buffer((T.int64(1), T.int64(2048)), "float16"), model_layers_0_mlp_shared_expert_gate_weight2: T.Buffer((T.int64(1), T.int64(2048)), "float16"), compute_intermediate: T.Buffer((T.int64(1), T.int64(1)), "float16")):
        # Dot product of two (1, 2048) float16 rows followed by a sigmoid:
        #   compute_intermediate[0, 0] = sigmoid( sum_k reshape220[0, k] * weight[0, k] )
        # Everything, including the accumulation, stays in float16.
        # No thread bindings appear here (plain loop nests).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1)), "float16")
        for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(reshape220[v_i0, v_k], model_layers_0_mlp_shared_expert_gate_weight2[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + reshape220[v_i0, v_k] * model_layers_0_mlp_shared_expert_gate_weight2[v_i1, v_k]
        # Elementwise sigmoid of the reduced value.
        for i0, i1 in T.grid(T.int64(1), T.int64(1)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(NT_matmul_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.sigmoid(NT_matmul_intermediate[v_i0, v_i1])

    @T.prim_func(private=True)
    def fused_NT_matmul2_cast(p_reshape628: T.handle, model_layers_0_mlp_gate_weight4: T.Buffer((T.int64(60), T.int64(2048)), "float16"), p_output0: T.handle):
        # Matrix product with a transposed right operand, followed by a cast,
        # for a dynamic number of rows:
        #   compute_intermediate[m, n] = float32( sum_k reshape628[m, k] * weight[n, k] )
        # reshape628 is (batch_size, 2048) float16, weight is (60, 2048) float16 and
        # the output is (batch_size, 60). No thread bindings appear here.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape628 = T.match_buffer(p_reshape628, (batch_size, T.int64(2048)), "float16")
        # No dtype is passed for the output buffer; the values stored into it below
        # are explicitly cast to float32.
        compute_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(60)))
        # with T.block("root"):
        # The reduction is accumulated in float16; widening happens only afterwards.
        NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(60)), "float16")
        for i0, i1, k in T.grid(batch_size, T.int64(60), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(reshape628[v_i0, v_k], model_layers_0_mlp_gate_weight4[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + reshape628[v_i0, v_k] * model_layers_0_mlp_gate_weight4[v_i1, v_k]
        # Elementwise cast of the float16 result to float32.
        for i0, i1 in T.grid(batch_size, T.int64(60)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(NT_matmul_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1])

    @T.prim_func(private=True)
    def fused_NT_matmul5_tir_sigmoid(p_reshape628: T.handle, model_layers_0_mlp_shared_expert_gate_weight4: T.Buffer((T.int64(1), T.int64(2048)), "float16"), p_output0: T.handle):
        # Per-row dot product with a single (1, 2048) weight row, followed by a sigmoid:
        #   compute_intermediate[m, 0] = sigmoid( sum_k reshape628[m, k] * weight[0, k] )
        # reshape628 is (batch_size, 2048) float16 and the output is (batch_size, 1)
        # float16; the accumulation also stays in float16.
        # No thread bindings appear here (plain loop nests).
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape628 = T.match_buffer(p_reshape628, (batch_size, T.int64(2048)), "float16")
        compute_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1)), "float16")
        # with T.block("root"):
        NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1)), "float16")
        for i0, i1, k in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(reshape628[v_i0, v_k], model_layers_0_mlp_shared_expert_gate_weight4[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + reshape628[v_i0, v_k] * model_layers_0_mlp_shared_expert_gate_weight4[v_i1, v_k]
        # Elementwise sigmoid of each reduced value.
        for i0, i1 in T.grid(batch_size, T.int64(1)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(NT_matmul_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.sigmoid(NT_matmul_intermediate[v_i0, v_i1])

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul10_add3(model_layers_0_self_attn_c_attn_q_weight2: T.Buffer((T.int64(6144), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale2: T.Buffer((T.int64(6144), T.int64(64)), "float16"), rms_norm49: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), model_layers_0_self_attn_c_attn_bias2: T.Buffer((T.int64(6144),), "float16"), T_add_intermediate_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(6144)), "float16")):
        # Fused 4-bit weight dequantization + matmul + bias add, fully static shapes:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[0, 0, j]  = sum_k rms_norm49[0, 0, k] * W[j, k] + bias[j]
        # with j in [0, 6144) and k in [0, 2048); everything after unpacking is float16.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles, dequantized weight, and matmul result.
        compute = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(6144)), "float16")
        # Unpack: column v_i1 lives in uint32 word v_i1 // 8, at bit offset (v_i1 % 8) * 4;
        # the mask 15 keeps that 4-bit field, which is then cast to float16.
        for i0, i1 in T.grid(T.int64(6144), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_c_attn_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: subtract the offset 7.0, then multiply by the scale shared by each
        # run of 32 consecutive columns (index v_i1 // 32).
        for i0, i1 in T.grid(T.int64(6144), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_c_attn_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale2[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(6144), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm49[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm49[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Add the bias, indexed by the last (output-feature) axis only.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(6144)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], model_layers_0_self_attn_c_attn_bias2[v_ax2])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + model_layers_0_self_attn_c_attn_bias2[v_ax2]

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul7_add2(model_layers_0_self_attn_c_attn_q_weight3: T.Buffer((T.int64(6144), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale3: T.Buffer((T.int64(6144), T.int64(64)), "float16"), p_rms_norm98: T.handle, model_layers_0_self_attn_c_attn_bias3: T.Buffer((T.int64(6144),), "float16"), p_output0: T.handle):
        # Fused 4-bit weight dequantization + matmul + bias add with a symbolic middle axis:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[0, s, j]  = sum_k rms_norm98[0, s, k] * W[j, k] + bias[j]
        # for s in [0, seq_len), j in [0, 6144), k in [0, 2048).
        T.func_attr({"tir.noalias": T.bool(True)})
        # seq_len is a symbolic int64 bound by the match_buffer shapes below.
        seq_len = T.int64()
        rms_norm98 = T.match_buffer(p_rms_norm98, (T.int64(1), seq_len, T.int64(2048)), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(6144)), "float16")
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles, dequantized weight, and matmul result.
        compute = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(6144)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(6144), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_c_attn_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(6144), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_c_attn_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale3[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(6144), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm98[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm98[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Add the bias, indexed by the last (output-feature) axis only.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(6144)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], model_layers_0_self_attn_c_attn_bias3[v_ax2])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + model_layers_0_self_attn_c_attn_bias3[v_ax2]

    @T.prim_func(private=True)
    def fused_dequantize1_fused_NT_matmul_add(model_layers_0_self_attn_c_attn_q_weight4: T.Buffer((T.int64(6144), T.int64(256)), "uint32"), model_layers_0_self_attn_c_attn_q_scale4: T.Buffer((T.int64(6144), T.int64(64)), "float16"), p_rms_norm147: T.handle, model_layers_0_self_attn_c_attn_bias4: T.Buffer((T.int64(6144),), "float16"), p_output0: T.handle):
        # Fused 4-bit weight dequantization + matmul + bias add with a symbolic leading axis:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[b, 0, j]  = sum_k rms_norm147[b, 0, k] * W[j, k] + bias[j]
        # for b in [0, batch_size), j in [0, 6144), k in [0, 2048).
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic int64 bound by the match_buffer shapes below.
        batch_size = T.int64()
        rms_norm147 = T.match_buffer(p_rms_norm147, (batch_size, T.int64(1), T.int64(2048)), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(6144)), "float16")
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles, dequantized weight, and matmul result.
        compute = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(6144), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(6144)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(6144), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_c_attn_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_c_attn_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(6144), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_c_attn_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_c_attn_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(6144), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm147[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm147[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Add the bias, indexed by the last (output-feature) axis only.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(6144)):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(NT_matmul_intermediate[v_ax0, v_ax1, v_ax2], model_layers_0_self_attn_c_attn_bias4[v_ax2])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + model_layers_0_self_attn_c_attn_bias4[v_ax2]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul1(model_layers_0_self_attn_o_proj_q_weight4: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale4: T.Buffer((T.int64(2048), T.int64(64)), "float16"), p_reshape627: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + matmul (no bias), symbolic leading axis:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[b, 0, j]  = sum_k reshape627[b, 0, k] * W[j, k]
        # for b in [0, batch_size), j and k in [0, 2048).
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic int64 bound by the match_buffer shapes below.
        batch_size = T.int64()
        reshape627 = T.match_buffer(p_reshape627, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # The reduction accumulates directly into the output buffer.
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles and the dequantized weight.
        compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_o_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_o_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape627[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape627[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul11(model_layers_0_self_attn_o_proj_q_weight2: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale2: T.Buffer((T.int64(2048), T.int64(64)), "float16"), lv653: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Fused 4-bit weight dequantization + matmul (no bias), fully static shapes:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[0, 0, j]  = sum_k lv653[0, 0, k] * W[j, k]
        # for j and k in [0, 2048). The result accumulates directly into the output
        # parameter NT_matmul_intermediate.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles and the dequantized weight.
        compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_o_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_o_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(lv653[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv653[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize2_NT_matmul8(model_layers_0_self_attn_o_proj_q_weight3: T.Buffer((T.int64(2048), T.int64(256)), "uint32"), model_layers_0_self_attn_o_proj_q_scale3: T.Buffer((T.int64(2048), T.int64(64)), "float16"), p_reshape411: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + matmul (no bias), symbolic middle axis:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[0, s, j]  = sum_k reshape411[0, s, k] * W[j, k]
        # for s in [0, seq_len), j and k in [0, 2048).
        T.func_attr({"tir.noalias": T.bool(True)})
        # seq_len is a symbolic int64 bound by the match_buffer shapes below.
        seq_len = T.int64()
        reshape411 = T.match_buffer(p_reshape411, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # The reduction accumulates directly into the output buffer.
        NT_matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles and the dequantized weight.
        compute = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(2048)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_self_attn_o_proj_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_self_attn_o_proj_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_self_attn_o_proj_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_self_attn_o_proj_q_scale3[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(2048), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(reshape411[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + reshape411[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul13(model_layers_0_mlp_shared_expert_gate_up_proj_q_weight2: T.Buffer((T.int64(11264), T.int64(256)), "uint32"), model_layers_0_mlp_shared_expert_gate_up_proj_q_scale2: T.Buffer((T.int64(11264), T.int64(64)), "float16"), lv655: T.Buffer((T.int64(1), T.int64(2048)), "float16"), NT_matmul_intermediate: T.Buffer((T.int64(1), T.int64(11264)), "float16")):
        # Fused 4-bit weight dequantization + 2-D matmul (no bias), fully static shapes:
        #   W[j, k]    = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[0, j]  = sum_k lv655[0, k] * W[j, k]
        # for j in [0, 11264), k in [0, 2048). The result accumulates directly into the
        # output parameter NT_matmul_intermediate.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles and the dequantized weight.
        compute = T.alloc_buffer((T.int64(11264), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(11264), T.int64(2048)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(11264), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_shared_expert_gate_up_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_shared_expert_gate_up_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(11264), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_shared_expert_gate_up_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_shared_expert_gate_up_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, k in T.grid(T.int64(1), T.int64(11264), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv655[v_i0, v_k], dequantize_intermediate[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + lv655[v_i0, v_k] * dequantize_intermediate[v_i1, v_k]

    @T.prim_func(private=True)
    def fused_dequantize3_NT_matmul3(model_layers_0_mlp_shared_expert_gate_up_proj_q_weight4: T.Buffer((T.int64(11264), T.int64(256)), "uint32"), model_layers_0_mlp_shared_expert_gate_up_proj_q_scale4: T.Buffer((T.int64(11264), T.int64(64)), "float16"), p_reshape628: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + 2-D matmul (no bias), symbolic row count:
        #   W[j, k]    = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[b, j]  = sum_k reshape628[b, k] * W[j, k]
        # for b in [0, batch_size), j in [0, 11264), k in [0, 2048).
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic int64 bound by the match_buffer shapes below.
        batch_size = T.int64()
        reshape628 = T.match_buffer(p_reshape628, (batch_size, T.int64(2048)), "float16")
        # The reduction accumulates directly into the output buffer.
        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(11264)), "float16")
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles and the dequantized weight.
        compute = T.alloc_buffer((T.int64(11264), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(11264), T.int64(2048)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(11264), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_shared_expert_gate_up_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_shared_expert_gate_up_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(11264), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_shared_expert_gate_up_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_shared_expert_gate_up_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, k in T.grid(batch_size, T.int64(11264), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(reshape628[v_i0, v_k], dequantize_intermediate[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + reshape628[v_i0, v_k] * dequantize_intermediate[v_i1, v_k]

    @T.prim_func(private=True)
    def fused_dequantize4_fused_NT_matmul14_multiply7_add4(model_layers_0_mlp_shared_expert_down_proj_q_weight2: T.Buffer((T.int64(2048), T.int64(704)), "uint32"), model_layers_0_mlp_shared_expert_down_proj_q_scale2: T.Buffer((T.int64(2048), T.int64(176)), "float16"), lv662: T.Buffer((T.int64(1), T.int64(5632)), "float16"), lv663: T.Buffer((T.int64(1), T.int64(1)), "float16"), lv661: T.Buffer((T.int64(1), T.int64(2048)), "float16"), T_add_intermediate_intermediate: T.Buffer((T.int64(1), T.int64(2048)), "float16")):
        # Fused 4-bit weight dequantization + matmul + row-scalar multiply + add, static shapes:
        #   W[j, k]    = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[0, j]  = lv661[0, j] + lv663[0, 0] * sum_k lv662[0, k] * W[j, k]
        # for j in [0, 2048), k in [0, 5632); all arithmetic after unpacking is float16.
        # NOTE(review): the names suggest lv663 is a gate value and lv661 a tensor being
        # accumulated onto -- confirm against the calling Relax function.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Stage buffers: unpacked nibbles, dequantized weight, matmul result, scaled result.
        compute = T.alloc_buffer((T.int64(2048), T.int64(5632)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(5632)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(2048)), "float16")
        T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(2048)), "float16")
        # Unpack: eight 4-bit fields per uint32 word (704 words -> 5632 columns); select
        # word v_i1 // 8, shift right by (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(2048), T.int64(5632)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_shared_expert_down_proj_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_shared_expert_down_proj_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(5632)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_shared_expert_down_proj_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_shared_expert_down_proj_q_scale2[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, k in T.grid(T.int64(1), T.int64(2048), T.int64(5632)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv662[v_i0, v_k], dequantize_intermediate[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + lv662[v_i0, v_k] * dequantize_intermediate[v_i1, v_k]
        # Scale every element of a row by that row's single lv663 value (column 0).
        for ax0, ax1 in T.grid(T.int64(1), T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv663[v_ax0, T.int64(0)], NT_matmul_intermediate[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1])
                T_multiply_intermediate[v_ax0, v_ax1] = lv663[v_ax0, T.int64(0)] * NT_matmul_intermediate[v_ax0, v_ax1]
        # Elementwise add of lv661 to the scaled product, written to the output.
        for ax0, ax1 in T.grid(T.int64(1), T.int64(2048)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv661[v_ax0, v_ax1], T_multiply_intermediate[v_ax0, v_ax1])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1])
                T_add_intermediate_intermediate[v_ax0, v_ax1] = lv661[v_ax0, v_ax1] + T_multiply_intermediate[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def fused_dequantize4_fused_NT_matmul4_multiply3_add1(model_layers_0_mlp_shared_expert_down_proj_q_weight4: T.Buffer((T.int64(2048), T.int64(704)), "uint32"), model_layers_0_mlp_shared_expert_down_proj_q_scale4: T.Buffer((T.int64(2048), T.int64(176)), "float16"), p_lv6: T.handle, p_lv7: T.handle, p_lv5: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + matmul + row-scalar multiply + add, symbolic rows:
        #   W[j, k]    = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[b, j]  = lv5[b, j] + lv7[b, 0] * sum_k lv6[b, k] * W[j, k]
        # for b in [0, batch_size), j in [0, 2048), k in [0, 5632).
        # NOTE(review): the names suggest lv7 is a per-row gate value and lv5 a tensor being
        # accumulated onto -- confirm against the calling Relax function.
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic int64 shared by all four handle-backed buffers below.
        batch_size = T.int64()
        lv6 = T.match_buffer(p_lv6, (batch_size, T.int64(5632)), "float16")
        lv7 = T.match_buffer(p_lv7, (batch_size, T.int64(1)), "float16")
        lv5 = T.match_buffer(p_lv5, (batch_size, T.int64(2048)), "float16")
        T_add_intermediate_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(2048)), "float16")
        # with T.block("root"):
        # Stage buffers: unpacked nibbles, dequantized weight, matmul result, scaled result.
        compute = T.alloc_buffer((T.int64(2048), T.int64(5632)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(2048), T.int64(5632)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(2048)), "float16")
        T_multiply_intermediate = T.alloc_buffer((batch_size, T.int64(2048)), "float16")
        # Unpack: eight 4-bit fields per uint32 word (704 words -> 5632 columns); select
        # word v_i1 // 8, shift right by (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(2048), T.int64(5632)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_layers_0_mlp_shared_expert_down_proj_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_layers_0_mlp_shared_expert_down_proj_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(2048), T.int64(5632)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], model_layers_0_mlp_shared_expert_down_proj_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * model_layers_0_mlp_shared_expert_down_proj_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, k in T.grid(batch_size, T.int64(2048), T.int64(5632)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(lv6[v_i0, v_k], dequantize_intermediate[v_i1, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1] = NT_matmul_intermediate[v_i0, v_i1] + lv6[v_i0, v_k] * dequantize_intermediate[v_i1, v_k]
        # Scale every element of a row by that row's single lv7 value (column 0).
        for ax0, ax1 in T.grid(batch_size, T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv7[v_ax0, T.int64(0)], NT_matmul_intermediate[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1])
                T_multiply_intermediate[v_ax0, v_ax1] = lv7[v_ax0, T.int64(0)] * NT_matmul_intermediate[v_ax0, v_ax1]
        # Elementwise add of lv5 to the scaled product, written to the output.
        for ax0, ax1 in T.grid(batch_size, T.int64(2048)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv5[v_ax0, v_ax1], T_multiply_intermediate[v_ax0, v_ax1])
                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1])
                T_add_intermediate_intermediate[v_ax0, v_ax1] = lv5[v_ax0, v_ax1] + T_multiply_intermediate[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def fused_dequantize_fused_NT_matmul16_cast6(lm_head_q_weight2: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), lm_head_q_scale2: T.Buffer((T.int64(151936), T.int64(64)), "float16"), rms_norm97: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), compute_intermediate_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(151936)), "float32")):
        # Fused 4-bit weight dequantization + matmul + cast to float32, fully static shapes:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[0, 0, j]  = float32(sum_k rms_norm97[0, 0, k] * W[j, k])
        # for j in [0, 151936), k in [0, 2048). The sum itself is accumulated in float16;
        # only the finished value is widened.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles, dequantized weight, and float16 matmul result.
        compute = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(151936)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lm_head_q_weight2[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lm_head_q_weight2[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], lm_head_q_scale2[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * lm_head_q_scale2[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm97[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm97[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Widen each finished float16 value to float32 in the output buffer.
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(151936)):
            with T.block("compute_1"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute_intermediate_intermediate[v_i0, v_i1, v_i2])
                compute_intermediate_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2])

    @T.prim_func(private=True)
    def fused_dequantize_fused_NT_matmul6_cast2(lm_head_q_weight4: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), lm_head_q_scale4: T.Buffer((T.int64(151936), T.int64(64)), "float16"), p_rms_norm195: T.handle, p_output0: T.handle):
        # Fused 4-bit weight dequantization + matmul + cast to float32, symbolic leading axis:
        #   W[j, k]       = (nibble(q_weight, j, k) - 7) * q_scale[j, k // 32]
        #   out[b, 0, j]  = float32(sum_k rms_norm195[b, 0, k] * W[j, k])
        # for b in [0, batch_size), j in [0, 151936), k in [0, 2048). The sum itself is
        # accumulated in float16; only the finished value is widened.
        T.func_attr({"tir.noalias": T.bool(True)})
        # batch_size is a symbolic int64 bound by the match_buffer shapes below.
        batch_size = T.int64()
        rms_norm195 = T.match_buffer(p_rms_norm195, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # No dtype argument is given here; the value stored below is cast to "float32".
        # NOTE(review): presumably float32 is match_buffer's default dtype -- confirm.
        compute_intermediate_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(151936)))
        # with T.block("root"):
        # Stage buffers: raw unpacked nibbles, dequantized weight, and float16 matmul result.
        compute = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((batch_size, T.int64(1), T.int64(151936)), "float16")
        # Unpack: eight 4-bit fields per uint32 word; select word v_i1 // 8, shift right by
        # (v_i1 % 8) * 4 bits, mask with 15, cast to float16.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lm_head_q_weight4[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lm_head_q_weight4[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: (value - 7.0) * scale, one scale per 32 consecutive columns.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], lm_head_q_scale4[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * lm_head_q_scale4[v_i0, v_i1 // T.int64(32)]
        # Matmul against the weight indexed [output, k]; k is the reduction axis ("R").
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(rms_norm195[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                # Zero starting value for the float16 accumulation.
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm195[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Widen each finished float16 value to float32 in the output buffer.
        for i0, i1, i2 in T.grid(batch_size, T.int64(1), T.int64(151936)):
            with T.block("compute_1"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute_intermediate_intermediate[v_i0, v_i1, v_i2])
                compute_intermediate_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2])

    @T.prim_func(private=True)
    def fused_dequantize_fused_NT_matmul9_cast3(lm_head_q_weight3: T.Buffer((T.int64(151936), T.int64(256)), "uint32"), lm_head_q_scale3: T.Buffer((T.int64(151936), T.int64(64)), "float16"), p_take49: T.handle, p_output0: T.handle):
        # Dequantize a packed 4-bit weight matrix, multiply the activations by its
        # transpose, and cast the float16 product to float32. Here the dynamic
        # batch_size extent is the second axis of both input and output.
        #
        # lm_head_q_weight3: (151936, 256) uint32; every word packs eight 4-bit values.
        # lm_head_q_scale3:  (151936, 64) float16; one scale per 32 unpacked columns.
        # p_take49:          handle bound to a (1, batch_size, 2048) float16 buffer.
        # p_output0:         handle bound to a (1, batch_size, 151936) output buffer;
        #                    it receives the float32-cast result.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        take49 = T.match_buffer(p_take49, (T.int64(1), batch_size, T.int64(2048)), "float16")
        compute_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), batch_size, T.int64(151936)))
        # with T.block("root"):
        compute = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        dequantize_intermediate = T.alloc_buffer((T.int64(151936), T.int64(2048)), "float16")
        NT_matmul_intermediate = T.alloc_buffer((T.int64(1), batch_size, T.int64(151936)), "float16")
        # Unpack: column i1 lives in word i1 // 8, at bit offset (i1 % 8) * 4; the mask 15 keeps 4 bits.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(lm_head_q_weight3[v_i0, v_i1 // T.int64(8)])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lm_head_q_weight3[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
        # Dequantize: subtract the fixed offset 7, then apply the scale of the 32-column group.
        for i0, i1 in T.grid(T.int64(151936), T.int64(2048)):
            with T.block("dequantize"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(compute[v_i0, v_i1], lm_head_q_scale3[v_i0, v_i1 // T.int64(32)])
                T.writes(dequantize_intermediate[v_i0, v_i1])
                dequantize_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7.0)) * lm_head_q_scale3[v_i0, v_i1 // T.int64(32)]
        # Matmul against the transposed weight: out[0, b, n] = sum_k x[0, b, k] * W[n, k], accumulated in float16.
        for i0, i1, i2, k in T.grid(T.int64(1), batch_size, T.int64(151936), T.int64(2048)):
            with T.block("NT_matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                T.reads(take49[v_i0, v_i1, v_k], dequantize_intermediate[v_i2, v_k])
                T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                with T.init():
                    NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0)
                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + take49[v_i0, v_i1, v_k] * dequantize_intermediate[v_i2, v_k]
        # Widen the float16 result to float32 for the output buffer.
        for i0, i1, i2 in T.grid(T.int64(1), batch_size, T.int64(151936)):
            with T.block("compute_1"):
                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(NT_matmul_intermediate[v_i0, v_i1, v_i2])
                T.writes(compute_intermediate_intermediate[v_i0, v_i1, v_i2])
                compute_intermediate_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", NT_matmul_intermediate[v_i0, v_i1, v_i2])

    @T.prim_func(private=True)
    def fused_dequantize_take2(model_embed_tokens_q_weight: T.Buffer((151936, 256), "uint32"), model_embed_tokens_q_scale: T.Buffer((151936, 64), "float16"), p_input_ids: T.handle, p_output0: T.handle):
        # Gather rows of a packed 4-bit table by index and dequantize them to float16.
        #
        # model_embed_tokens_q_weight: (151936, 256) uint32; every word packs eight 4-bit values.
        # model_embed_tokens_q_scale:  (151936, 64) float16; one scale per 32 unpacked columns.
        # p_input_ids:                 handle bound to a (seq_len,) int32 buffer of row indices.
        # p_output0:                   handle bound to the (seq_len, 2048) float16 output.
        # NOTE(review): no bounds check on input_ids is visible here; indices are
        # presumably validated by the caller.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int32()
        input_ids = T.match_buffer(p_input_ids, (seq_len,), "int32")
        T_take_intermediate = T.match_buffer(p_output0, (seq_len, 2048), "float16")
        # with T.block("root"):
        compute = T.alloc_buffer((151936, 2048), "float16")
        # Unpack the whole table: column i1 lives in word i1 // 8, at bit offset (i1 % 8) * 4.
        for i0, i1 in T.grid(151936, 2048):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(model_embed_tokens_q_weight[v_i0, v_i1 // 8])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(model_embed_tokens_q_weight[v_i0, v_i1 // 8], T.Cast("uint32", v_i1 % 8 * 4)), T.uint32(15)))
        # Gather row input_ids[ax0]; the (q - 7) * scale step is applied only to the gathered rows.
        for ax0, ax1 in T.grid(seq_len, 2048):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(input_ids[v_ax0], compute[input_ids[v_ax0], v_ax1], model_embed_tokens_q_scale[input_ids[v_ax0], v_ax1 // 32])
                T.writes(T_take_intermediate[v_ax0, v_ax1])
                T_take_intermediate[v_ax0, v_ax1] = (compute[input_ids[v_ax0], v_ax1] - T.float16(7.0)) * model_embed_tokens_q_scale[input_ids[v_ax0], v_ax1 // 32]

    @T.prim_func(private=True)
    def fused_expert_mask_transpose(p_top4_softmax_172: T.handle, p_output0: T.handle):
        # Build a 0/1 membership mask from four int32 indices per row, then transpose it.
        #
        # p_top4_softmax_172: handle bound to a (batch_size, 4) int32 buffer of indices.
        # p_output0:          handle bound to the (60, batch_size) int32 output; entry
        #                     [j, i] is 1 when any of the four indices in row i equals j.
        # NOTE(review): the names suggest the indices are top-4 expert ids out of 60
        # experts; that is inferred from naming only.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        top4_softmax_172 = T.match_buffer(p_top4_softmax_172, (batch_size, T.int64(4)), "int32")
        T_transpose_intermediate = T.match_buffer(p_output0, (T.int64(60), batch_size), "int32")
        # with T.block("root"):
        compute_intermediate = T.alloc_buffer((batch_size, T.int64(60)), "int32")
        # Mask in (batch_size, 60) layout; indices are widened to int64 to compare against the int64 axis v_j.
        for i, j in T.grid(batch_size, T.int64(60)):
            with T.block("compute"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                T.reads(top4_softmax_172[v_i, T.int64(0):T.int64(4)])
                T.writes(compute_intermediate[v_i, v_j])
                compute_intermediate[v_i, v_j] = T.Select(T.Cast("int64", top4_softmax_172[v_i, T.int64(0)]) == v_j or T.Cast("int64", top4_softmax_172[v_i, T.int64(1)]) == v_j or T.Cast("int64", top4_softmax_172[v_i, T.int64(2)]) == v_j or T.Cast("int64", top4_softmax_172[v_i, T.int64(3)]) == v_j, 1, 0)
        # Swap the two axes so the output is (60, batch_size).
        for ax0, ax1 in T.grid(T.int64(60), batch_size):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(compute_intermediate[v_ax1, v_ax0])
                T.writes(T_transpose_intermediate[v_ax0, v_ax1])
                T_transpose_intermediate[v_ax0, v_ax1] = compute_intermediate[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def fused_multiply1_sum(p_reshape631: T.handle, p_reshape630: T.handle, p_output0: T.handle):
        # Weighted sum over the middle axis: out[b, c] = sum_j x[b, j, c] * w[b, j, 0].
        #
        # p_reshape631: handle bound to a (batch_size, 4, 2048) float16 buffer (x).
        # p_reshape630: handle bound to a (batch_size, 4, 1) float16 buffer (w).
        # p_output0:    handle bound to the (batch_size, 2048) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape631 = T.match_buffer(p_reshape631, (batch_size, T.int64(4), T.int64(2048)), "float16")
        reshape630 = T.match_buffer(p_reshape630, (batch_size, T.int64(4), T.int64(1)), "float16")
        mul289_red_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(2048)), "float16")
        # with T.block("root"):
        T_multiply_intermediate = T.alloc_buffer((batch_size, T.int64(4), T.int64(2048)), "float16")
        # Broadcast multiply: the size-1 last axis of reshape630 is always read at index 0.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(4), T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(reshape631[v_ax0, v_ax1, v_ax2], reshape630[v_ax0, v_ax1, T.int64(0)])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = reshape631[v_ax0, v_ax1, v_ax2] * reshape630[v_ax0, v_ax1, T.int64(0)]
        # Reduce over the axis of extent 4 (k1), accumulating in float16.
        for ax0, ax1, k1 in T.grid(batch_size, T.int64(2048), T.int64(4)):
            with T.block("mul289_red"):
                v_ax0, v_ax1, v_k1 = T.axis.remap("SSR", [ax0, ax1, k1])
                T.reads(T_multiply_intermediate[v_ax0, v_k1, v_ax1])
                T.writes(mul289_red_intermediate[v_ax0, v_ax1])
                with T.init():
                    mul289_red_intermediate[v_ax0, v_ax1] = T.float16(0.0)
                mul289_red_intermediate[v_ax0, v_ax1] = mul289_red_intermediate[v_ax0, v_ax1] + T_multiply_intermediate[v_ax0, v_k1, v_ax1]

    @T.prim_func(private=True)
    def fused_multiply5_sum1(reshape222: T.Buffer((T.int64(1), T.int64(4), T.int64(2048)), "float16"), reshape221: T.Buffer((T.int64(1), T.int64(4), T.int64(1)), "float16"), mul97_red_intermediate: T.Buffer((T.int64(1), T.int64(2048)), "float16")):
        # Weighted sum over the middle axis for a fixed leading extent of 1:
        # out[0, c] = sum_j x[0, j, c] * w[0, j, 0].
        #
        # reshape222:             (1, 4, 2048) float16 input (x).
        # reshape221:             (1, 4, 1) float16 input (w).
        # mul97_red_intermediate: (1, 2048) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(4), T.int64(2048)), "float16")
        # Broadcast multiply: the size-1 last axis of reshape221 is always read at index 0.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(reshape222[v_ax0, v_ax1, v_ax2], reshape221[v_ax0, v_ax1, T.int64(0)])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
                T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = reshape222[v_ax0, v_ax1, v_ax2] * reshape221[v_ax0, v_ax1, T.int64(0)]
        # Reduce over the axis of extent 4 (k1), accumulating in float16.
        for ax0, ax1, k1 in T.grid(T.int64(1), T.int64(2048), T.int64(4)):
            with T.block("mul97_red"):
                v_ax0, v_ax1, v_k1 = T.axis.remap("SSR", [ax0, ax1, k1])
                T.reads(T_multiply_intermediate[v_ax0, v_k1, v_ax1])
                T.writes(mul97_red_intermediate[v_ax0, v_ax1])
                with T.init():
                    mul97_red_intermediate[v_ax0, v_ax1] = T.float16(0.0)
                mul97_red_intermediate[v_ax0, v_ax1] = mul97_red_intermediate[v_ax0, v_ax1] + T_multiply_intermediate[v_ax0, v_k1, v_ax1]

    @T.prim_func(private=True)
    def fused_reshape15_reshape16(add96: T.Buffer((T.int64(1), T.int64(1), T.int64(6144)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(48), T.int64(128)), "float16")):
        # Two chained reshapes: (1, 1, 6144) -> (1, 1, 48, 128) -> (1, 48, 128).
        #
        # add96:                    (1, 1, 6144) float16 input.
        # T_reshape_intermediate_1: (1, 48, 128) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(48), T.int64(128)), "float16")
        # Split the last axis: element [.., ax2, ax3] comes from flat index ax2 * 128 + ax3.
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(48), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(add96[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(6144)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = add96[T.int64(0), T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(6144)]
        # Drop one unit axis. ax2 < 128, so ax2 // 128 is 0 and this is an element-wise copy.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(48), T.int64(128)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(128) + v_ax1) % T.int64(48), v_ax2 % T.int64(128)])
                T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), (v_ax2 // T.int64(128) + v_ax1) % T.int64(48), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def fused_reshape17_reshape18(lv292: T.Buffer((T.int64(1), T.int64(16), T.int64(128)), "float16"), T_reshape_intermediate_1: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Two chained reshapes: (1, 16, 128) -> (1, 1, 16, 128) -> (1, 1, 2048).
        #
        # lv292:                    (1, 16, 128) float16 input.
        # T_reshape_intermediate_1: (1, 1, 2048) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_reshape_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(16), T.int64(128)), "float16")
        # Insert a unit axis. ax3 < 128, so ax3 // 128 is 0 and this is an element-wise copy.
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(16), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(lv292[T.int64(0), (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = lv292[T.int64(0), (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)]
        # Merge the last two axes: flat index ax2 maps back to [ax2 // 128, ax2 % 128].
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_reshape_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2])
                T_reshape_intermediate_1[v_ax0, v_ax1, v_ax2] = T_reshape_intermediate[T.int64(0), T.int64(0), v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def fused_reshape19(lv288_0: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(2048)), "float16")):
        # Reshape (1, 1, 2048) -> (1, 2048) by dropping one unit axis.
        #
        # lv288_0:                (1, 1, 2048) float16 input.
        # T_reshape_intermediate: (1, 2048) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # ax1 < 2048, so the "% 2048" leaves the index unchanged; this is an element-wise copy.
        for ax0, ax1 in T.grid(T.int64(1), T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv288_0[T.int64(0), T.int64(0), v_ax1 % T.int64(2048)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1])
                T_reshape_intermediate[v_ax0, v_ax1] = lv288_0[T.int64(0), T.int64(0), v_ax1 % T.int64(2048)]

    @T.prim_func(private=True)
    def fused_reshape20(lv294_0: T.Buffer((T.int64(1), T.int64(4)), "float16"), T_reshape_intermediate: T.Buffer((T.int64(1), T.int64(4), T.int64(1)), "float16")):
        # Reshape (1, 4) -> (1, 4, 1) by appending a unit axis.
        #
        # lv294_0:                (1, 4) float16 input.
        # T_reshape_intermediate: (1, 4, 1) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # ax2 has extent 1 (always 0), so (ax1 + ax2) % 4 is just ax1.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(1)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(lv294_0[T.int64(0), (v_ax1 + v_ax2) % T.int64(4)])
                T.writes(T_reshape_intermediate[v_ax0, v_ax1, v_ax2])
                T_reshape_intermediate[v_ax0, v_ax1, v_ax2] = lv294_0[T.int64(0), (v_ax1 + v_ax2) % T.int64(4)]

    @T.prim_func
    def fused_rope(var_qkv: T.handle, var_position_map: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, apply_rope: T.int32):
        # Split a fused qkv buffer into q, k and v, optionally applying a rotary
        # position transform to q and k.
        #
        # var_qkv:          handle bound to a (seq_len, 48, 128) float16 buffer.
        # var_position_map: handle bound to a (seq_len,) int32 buffer of positions.
        # var_q/var_k/var_v: handles bound to (seq_len, 16, 128) float16 outputs.
        # apply_rope:       the rotation is applied to q and k only when this is > 0;
        #                   otherwise q and k are plain copies. v is always a plain copy.
        #
        # Axis 1 of qkv is partitioned as: h in [0, 16) -> q, [16, 32) -> k, [32, 48) -> v.
        # Rotation for element d (computed in float32, stored as float16):
        #   freq = position_map[s] / 1000000 ** ((d * 2 % 128) / 128)
        #   out  = cos(freq) * x[d] + sin(freq) * (-x[d + 64] if d < 64 else x[d - 64])
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        seq_len = T.int32()
        qkv = T.match_buffer(var_qkv, (seq_len, 48, 128), "float16")
        position_map = T.match_buffer(var_position_map, (seq_len,), "int32", offset_factor=1)
        q = T.match_buffer(var_q, (seq_len, 16, 128), "float16")
        k = T.match_buffer(var_k, (seq_len, 16, 128), "float16")
        v = T.match_buffer(var_v, (seq_len, 16, 128), "float16")
        # with T.block("root"):
        for iters_0, iters_1, iters_2 in T.grid(seq_len, 48, 128):
            with T.block("llama_fused_rope"):
                s, h, d = T.axis.remap("SSS", [iters_0, iters_1, iters_2])
                # The declared read region spans d - 64 .. d + 64, covering both partner elements used by the rotation.
                T.reads(position_map[s], qkv[s, h, d - 64:d - 64 + 129])
                T.writes(q[s, h, d], k[s, h - 16, d], v[s, h - 32, d])
                # d iterates over an extent of 128, so the "d < 128" term below holds on every iteration.
                if h < 16:
                    # freq is a placeholder variable bound by the T.Let(..., where={freq: ...}) below.
                    freq = T.float32()
                    q[s, h, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                else:
                    if h < 32:
                        # Same rotation as for q, written to k at head index h - 16.
                        freq = T.float32()
                        k[s, h - 16, d] = T.if_then_else(apply_rope > 0 and d < 128, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", qkv[s, h, d]) + T.sin(freq) * T.Cast("float32", T.if_then_else(d < 64, qkv[s, h, d + 64] * T.float16(-1.0), qkv[s, h, d - 64]))), where={freq: T.Cast("float32", position_map[s]) / T.pow(T.float32(1000000.0), T.Cast("float32", d * 2 % 128) / T.float32(128.0))}), qkv[s, h, d])
                    else:
                        # v is copied through unchanged at head index h - 32.
                        v[s, h - 32, d] = qkv[s, h, d]

    @T.prim_func(private=True)
    def fused_softmax1_cast5(astype49: T.Buffer((T.int64(1), T.int64(60)), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(60)), "float16")):
        # Row-wise softmax over the axis of extent 60, followed by a cast to float16.
        #
        # astype49:             (1, 60) float32 input.
        # compute_intermediate: (1, 60) float16 output.
        #
        # The row maximum is subtracted before exponentiation, so the largest
        # exponent argument is 0.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((T.int64(1),))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(60)))
        T_softmax_expsum = T.alloc_buffer((T.int64(1),))
        T_softmax_norm_intermediate = T.alloc_buffer((T.int64(1), T.int64(60)))
        # Step 1: row maximum; the init value is the most negative finite float32.
        for i0, k in T.grid(T.int64(1), T.int64(60)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(astype49[v_i0, v_k])
                T.writes(T_softmax_maxelem[v_i0])
                with T.init():
                    T_softmax_maxelem[v_i0] = T.float32(-340282346638528859811704183484516925440.0)
                T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], astype49[v_i0, v_k])
        # Step 2: exp(x - max).
        for i0, i1 in T.grid(T.int64(1), T.int64(60)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(astype49[v_i0, v_i1], T_softmax_maxelem[v_i0])
                T.writes(T_softmax_exp[v_i0, v_i1])
                T_softmax_exp[v_i0, v_i1] = T.exp(astype49[v_i0, v_i1] - T_softmax_maxelem[v_i0])
        # Step 3: row sum of the exponentials.
        for i0, k in T.grid(T.int64(1), T.int64(60)):
            with T.block("T_softmax_expsum"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(T_softmax_exp[v_i0, v_k])
                T.writes(T_softmax_expsum[v_i0])
                with T.init():
                    T_softmax_expsum[v_i0] = T.float32(0.0)
                T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k]
        # Step 4: normalize each exponential by its row sum.
        for i0, i1 in T.grid(T.int64(1), T.int64(60)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0])
                T.writes(T_softmax_norm_intermediate[v_i0, v_i1])
                T.block_attr({"axis": 1})
                T_softmax_norm_intermediate[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0]
        # Step 5: narrow the normalized values to float16.
        for i0, i1 in T.grid(T.int64(1), T.int64(60)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_softmax_norm_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.Cast("float16", T_softmax_norm_intermediate[v_i0, v_i1])

    @T.prim_func(private=True)
    def fused_softmax_cast1(p_astype147: T.handle, p_output0: T.handle):
        # Row-wise softmax over the axis of extent 60 for a dynamic number of rows,
        # followed by a cast to float16.
        #
        # p_astype147: handle bound to a (batch_size, 60) input buffer (no dtype is
        #              given to T.match_buffer, so its default applies).
        # p_output0:   handle bound to the (batch_size, 60) float16 output.
        #
        # The row maximum is subtracted before exponentiation, so the largest
        # exponent argument is 0.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        astype147 = T.match_buffer(p_astype147, (batch_size, T.int64(60)))
        compute_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(60)), "float16")
        # with T.block("root"):
        T_softmax_maxelem = T.alloc_buffer((batch_size,))
        T_softmax_exp = T.alloc_buffer((batch_size, T.int64(60)))
        T_softmax_expsum = T.alloc_buffer((batch_size,))
        T_softmax_norm_intermediate = T.alloc_buffer((batch_size, T.int64(60)))
        # Step 1: row maximum; the init value is the most negative finite float32.
        for i0, k in T.grid(batch_size, T.int64(60)):
            with T.block("T_softmax_maxelem"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(astype147[v_i0, v_k])
                T.writes(T_softmax_maxelem[v_i0])
                with T.init():
                    T_softmax_maxelem[v_i0] = T.float32(-340282346638528859811704183484516925440.0)
                T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], astype147[v_i0, v_k])
        # Step 2: exp(x - max).
        for i0, i1 in T.grid(batch_size, T.int64(60)):
            with T.block("T_softmax_exp"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(astype147[v_i0, v_i1], T_softmax_maxelem[v_i0])
                T.writes(T_softmax_exp[v_i0, v_i1])
                T_softmax_exp[v_i0, v_i1] = T.exp(astype147[v_i0, v_i1] - T_softmax_maxelem[v_i0])
        # Step 3: row sum of the exponentials.
        for i0, k in T.grid(batch_size, T.int64(60)):
            with T.block("T_softmax_expsum"):
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(T_softmax_exp[v_i0, v_k])
                T.writes(T_softmax_expsum[v_i0])
                with T.init():
                    T_softmax_expsum[v_i0] = T.float32(0.0)
                T_softmax_expsum[v_i0] = T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k]
        # Step 4: normalize each exponential by its row sum.
        for i0, i1 in T.grid(batch_size, T.int64(60)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0])
                T.writes(T_softmax_norm_intermediate[v_i0, v_i1])
                T.block_attr({"axis": 1})
                T_softmax_norm_intermediate[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0]
        # Step 5: narrow the normalized values to float16.
        for i0, i1 in T.grid(batch_size, T.int64(60)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_softmax_norm_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.Cast("float16", T_softmax_norm_intermediate[v_i0, v_i1])

    @T.prim_func(private=True)
    def fused_split1_silu1_multiply2(p_lv3: T.handle, p_output0: T.handle):
        # Split the input in half along the last axis and compute
        # out = (a * sigmoid(a)) * b, where a is the first half and b the second.
        #
        # p_lv3:     handle bound to a (batch_size, 11264) float16 buffer.
        # p_output0: handle bound to the (batch_size, 5632) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv3 = T.match_buffer(p_lv3, (batch_size, T.int64(11264)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (batch_size, T.int64(5632)), "float16")
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((batch_size, T.int64(5632)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((batch_size, T.int64(5632)), "float16")
        compute = T.alloc_buffer((batch_size, T.int64(5632)), "float16")
        T_multiply_intermediate = T.alloc_buffer((batch_size, T.int64(5632)), "float16")
        # First half: columns [0, 5632).
        for ax0, ax1 in T.grid(batch_size, T.int64(5632)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv3[v_ax0, v_ax1])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1])
                T_split_sections_intermediate[v_ax0, v_ax1] = lv3[v_ax0, v_ax1]
        # Second half: columns [5632, 11264).
        for ax0, ax1 in T.grid(batch_size, T.int64(5632)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv3[v_ax0, v_ax1 + T.int64(5632)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1])
                T_split_sections_intermediate_1[v_ax0, v_ax1] = lv3[v_ax0, v_ax1 + T.int64(5632)]
        # sigmoid of the first half.
        for i0, i1 in T.grid(batch_size, T.int64(5632)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_split_sections_intermediate[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1])
        # a * sigmoid(a).
        for ax0, ax1 in T.grid(batch_size, T.int64(5632)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1], compute[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1])
                T_multiply_intermediate[v_ax0, v_ax1] = T_split_sections_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
        # Multiply element-wise by the second half.
        for ax0, ax1 in T.grid(batch_size, T.int64(5632)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_intermediate[v_ax0, v_ax1], T_split_sections_intermediate_1[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1])
                T_multiply_intermediate_1[v_ax0, v_ax1] = T_multiply_intermediate[v_ax0, v_ax1] * T_split_sections_intermediate_1[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def fused_split2_silu2_multiply4(lv295: T.Buffer((T.int64(4), T.int64(2816)), "float16"), T_multiply_intermediate_1: T.Buffer((T.int64(4), T.int64(1408)), "float16")):
        # Split the input in half along the last axis and compute
        # out = (a * sigmoid(a)) * b, where a is the first half and b the second.
        #
        # lv295:                     (4, 2816) float16 input.
        # T_multiply_intermediate_1: (4, 1408) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((T.int64(4), T.int64(1408)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((T.int64(4), T.int64(1408)), "float16")
        compute = T.alloc_buffer((T.int64(4), T.int64(1408)), "float16")
        T_multiply_intermediate = T.alloc_buffer((T.int64(4), T.int64(1408)), "float16")
        # First half: columns [0, 1408).
        for ax0, ax1 in T.grid(T.int64(4), T.int64(1408)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv295[v_ax0, v_ax1])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1])
                T_split_sections_intermediate[v_ax0, v_ax1] = lv295[v_ax0, v_ax1]
        # Second half: columns [1408, 2816).
        for ax0, ax1 in T.grid(T.int64(4), T.int64(1408)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv295[v_ax0, v_ax1 + T.int64(1408)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1])
                T_split_sections_intermediate_1[v_ax0, v_ax1] = lv295[v_ax0, v_ax1 + T.int64(1408)]
        # sigmoid of the first half.
        for i0, i1 in T.grid(T.int64(4), T.int64(1408)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_split_sections_intermediate[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1])
        # a * sigmoid(a).
        for ax0, ax1 in T.grid(T.int64(4), T.int64(1408)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1], compute[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1])
                T_multiply_intermediate[v_ax0, v_ax1] = T_split_sections_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
        # Multiply element-wise by the second half.
        for ax0, ax1 in T.grid(T.int64(4), T.int64(1408)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_intermediate[v_ax0, v_ax1], T_split_sections_intermediate_1[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1])
                T_multiply_intermediate_1[v_ax0, v_ax1] = T_multiply_intermediate[v_ax0, v_ax1] * T_split_sections_intermediate_1[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def fused_split3_silu3_multiply6(lv438: T.Buffer((T.int64(1), T.int64(11264)), "float16"), T_multiply_intermediate_1: T.Buffer((T.int64(1), T.int64(5632)), "float16")):
        # Split the input in half along the last axis and compute
        # out = (a * sigmoid(a)) * b, where a is the first half and b the second.
        #
        # lv438:                     (1, 11264) float16 input.
        # T_multiply_intermediate_1: (1, 5632) float16 output.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((T.int64(1), T.int64(5632)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(5632)), "float16")
        compute = T.alloc_buffer((T.int64(1), T.int64(5632)), "float16")
        T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(5632)), "float16")
        # First half: columns [0, 5632).
        for ax0, ax1 in T.grid(T.int64(1), T.int64(5632)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv438[v_ax0, v_ax1])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1])
                T_split_sections_intermediate[v_ax0, v_ax1] = lv438[v_ax0, v_ax1]
        # Second half: columns [5632, 11264).
        for ax0, ax1 in T.grid(T.int64(1), T.int64(5632)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv438[v_ax0, v_ax1 + T.int64(5632)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1])
                T_split_sections_intermediate_1[v_ax0, v_ax1] = lv438[v_ax0, v_ax1 + T.int64(5632)]
        # sigmoid of the first half.
        for i0, i1 in T.grid(T.int64(1), T.int64(5632)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_split_sections_intermediate[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1])
        # a * sigmoid(a).
        for ax0, ax1 in T.grid(T.int64(1), T.int64(5632)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1], compute[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1])
                T_multiply_intermediate[v_ax0, v_ax1] = T_split_sections_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
        # Multiply element-wise by the second half.
        for ax0, ax1 in T.grid(T.int64(1), T.int64(5632)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_intermediate[v_ax0, v_ax1], T_split_sections_intermediate_1[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1])
                T_multiply_intermediate_1[v_ax0, v_ax1] = T_multiply_intermediate[v_ax0, v_ax1] * T_split_sections_intermediate_1[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def fused_split_silu_multiply(p_lv780: T.handle, p_output0: T.handle, batch_size: T.int64):
        # Split the input in half along the last axis and compute
        # out = (a * sigmoid(a)) * b, where a is the first half and b the second.
        #
        # p_lv780:    handle bound to a (batch_size * 4, 2816) float16 buffer.
        # p_output0:  handle bound to the (batch_size * 4, 1408) float16 output.
        # batch_size: passed as an explicit int64 argument; the row extent is
        #             batch_size * 4.
        T.func_attr({"tir.noalias": T.bool(True)})
        lv780 = T.match_buffer(p_lv780, (batch_size * T.int64(4), T.int64(2816)), "float16")
        T_multiply_intermediate_1 = T.match_buffer(p_output0, (batch_size * T.int64(4), T.int64(1408)), "float16")
        # with T.block("root"):
        T_split_sections_intermediate = T.alloc_buffer((batch_size * T.int64(4), T.int64(1408)), "float16")
        T_split_sections_intermediate_1 = T.alloc_buffer((batch_size * T.int64(4), T.int64(1408)), "float16")
        compute = T.alloc_buffer((batch_size * T.int64(4), T.int64(1408)), "float16")
        T_multiply_intermediate = T.alloc_buffer((batch_size * T.int64(4), T.int64(1408)), "float16")
        # First half: columns [0, 1408).
        for ax0, ax1 in T.grid(batch_size * T.int64(4), T.int64(1408)):
            with T.block("T_split_sections"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv780[v_ax0, v_ax1])
                T.writes(T_split_sections_intermediate[v_ax0, v_ax1])
                T_split_sections_intermediate[v_ax0, v_ax1] = lv780[v_ax0, v_ax1]
        # Second half: columns [1408, 2816).
        for ax0, ax1 in T.grid(batch_size * T.int64(4), T.int64(1408)):
            with T.block("T_split_sections_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(lv780[v_ax0, v_ax1 + T.int64(1408)])
                T.writes(T_split_sections_intermediate_1[v_ax0, v_ax1])
                T_split_sections_intermediate_1[v_ax0, v_ax1] = lv780[v_ax0, v_ax1 + T.int64(1408)]
        # sigmoid of the first half.
        for i0, i1 in T.grid(batch_size * T.int64(4), T.int64(1408)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_split_sections_intermediate[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.sigmoid(T_split_sections_intermediate[v_i0, v_i1])
        # a * sigmoid(a).
        for ax0, ax1 in T.grid(batch_size * T.int64(4), T.int64(1408)):
            with T.block("T_multiply"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_split_sections_intermediate[v_ax0, v_ax1], compute[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate[v_ax0, v_ax1])
                T_multiply_intermediate[v_ax0, v_ax1] = T_split_sections_intermediate[v_ax0, v_ax1] * compute[v_ax0, v_ax1]
        # Multiply element-wise by the second half.
        for ax0, ax1 in T.grid(batch_size * T.int64(4), T.int64(1408)):
            with T.block("T_multiply_1"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_intermediate[v_ax0, v_ax1], T_split_sections_intermediate_1[v_ax0, v_ax1])
                T.writes(T_multiply_intermediate_1[v_ax0, v_ax1])
                T_multiply_intermediate_1[v_ax0, v_ax1] = T_multiply_intermediate[v_ax0, v_ax1] * T_split_sections_intermediate_1[v_ax0, v_ax1]

    @T.prim_func
    def gather_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row gather: dst[b, :] = src[indices[b], :] for every b in [0, batch_size).
        #   src     : (m, n)
        #   indices : (batch_size,) int32, used as row numbers into src
        #   dst     : (batch_size, n)
        # No range check on indices[b] appears here.
        # NOTE(review): presumably callers guarantee 0 <= indices[b] < m — confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (m, n))
        batch_size = T.int32(is_size_var=True)
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        dst = T.match_buffer(var_dst, (batch_size, n))
        # with T.block("root"):
        for b, j in T.grid(batch_size, n):
            with T.block("gather_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                T.reads(src[indices[vb], vj], indices[vb])
                T.writes(dst[vb, vj])
                dst[vb, vj] = src[indices[vb], vj]

    @T.prim_func(private=True)
    def get_expert_instance_indptr(var_cumsum: T.handle, indptr: T.Buffer((61,), "int32"), batch_size: T.int64):
        # Fills the 61-entry `indptr` from a flat cumulative-sum buffer of length
        # batch_size * 60:
        #     indptr[0] = 0
        #     indptr[i] = cumsum[i * batch_size - 1]   for i in 1..60
        # i.e. entry i takes the last element of the i-th batch_size-long segment.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        cumsum = T.match_buffer(var_cumsum, (batch_size * T.int64(60),), "int32")
        # with T.block("root"):
        for vi in range(61):
            with T.block("indptr"):
                i = T.axis.spatial(61, vi)
                T.reads(cumsum[T.Cast("int64", i) * batch_size - T.int64(1)])
                T.writes(indptr[i])
                # NOTE(review): for i == 0 the first Select operand indexes
                # cumsum[-1]; whether that load is actually evaluated depends on
                # how T.Select is lowered — confirm it cannot read out of bounds.
                indptr[i] = T.Select(i > 0, cumsum[T.Cast("int64", i) * batch_size - T.int64(1)], 0)

    @T.prim_func(private=True)
    def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
        # For each output row r, with s = sample_indices[r, 0] as the source row:
        # writes indices[s, j] to output_index[r, 0] for the column j that satisfies
        #     usample[r, 0] <  cumsum_sorted[s, j] / renorm_prob[s, 0]   (or j is the last column)
        # and either j == 0 or
        #     usample[r, 0] >= cumsum_sorted[s, j - 1] / renorm_prob[s, 0].
        # Buffers:
        #   A -> cumsum_sorted  (batch, vocab_size)
        #   B -> indices        (batch, vocab_size) int32
        #   C -> renorm_prob    (batch, 1)
        #   D -> usample        (out_batch, 1)
        #   E -> sample_indices (out_batch, 1) int32
        #   F -> output_index   (out_batch, 1) int32
        # NOTE(review): presumably cumsum_sorted is non-decreasing along axis 1 so
        # that exactly one column per row passes both tests — confirm with producer.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        indices = T.match_buffer(B, (batch, vocab_size), "int32")
        renorm_prob = T.match_buffer(C, (batch, 1))
        out_batch = T.int64(is_size_var=True)
        usample = T.match_buffer(D, (out_batch, 1))
        sample_indices = T.match_buffer(E, (out_batch, 1), "int32")
        output_index = T.match_buffer(F, (out_batch, 1), "int32")
        # with T.block("root"):
        for ax0, ax1 in T.grid(out_batch, vocab_size):
            with T.block("T_get_index_from_sorted"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
                T.writes(output_index[v_ax0, 0])
                # Upper-bound test; the last column is always accepted as a fallback.
                if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
                    # Column 0 has no predecessor, so no lower-bound test applies.
                    if v_ax1 == T.int64(0):
                        output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                    else:
                        # Lower-bound test against the previous column.
                        if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                            output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

    @T.prim_func(private=True)
    def get_indices(var_cumsum: T.handle, var_expert_indices: T.handle, var_reverse_indices: T.handle, var_token_indices: T.handle):
        # Scatters each (b, j) slot of expert_indices (batch_size x 4) to the
        # position given by the cumulative-sum buffer:
        #     e   = expert_indices[b, j]
        #     pos = cumsum[e * batch_size + b] - 1
        #     reverse_indices[pos] = b * 4 + j    (flattened slot id)
        #     token_indices[pos]   = b            (row the slot came from)
        # One thread handles one flattened slot; blocks of 1024 threads cover
        # batch_size * 4 slots and surplus threads are masked by the `if`.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        cumsum_len = T.int32(is_size_var=True)
        cumsum = T.match_buffer(var_cumsum, (cumsum_len,), "int32")
        batch_size = T.int32(is_size_var=True)
        expert_indices = T.match_buffer(var_expert_indices, (batch_size, 4), "int32")
        reverse_indices = T.match_buffer(var_reverse_indices, (batch_size * 4,), "int32")
        token_indices = T.match_buffer(var_token_indices, (batch_size * 4,), "int32")
        # with T.block("root"):
        for bj_o in T.thread_binding((batch_size * 4 + 1023) // 1024, thread="blockIdx.x"):
            for bj_i in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("indices"):
                    T.reads(expert_indices[0:batch_size, 0:4], cumsum[0:cumsum_len])
                    T.writes(reverse_indices[0:batch_size * 4], token_indices[0:batch_size * 4])
                    # Guard for the tail block when batch_size * 4 is not a multiple of 1024.
                    if bj_o * 1024 + bj_i < batch_size * 4:
                        b: T.int32 = (bj_o * 1024 + bj_i) // 4
                        j: T.int32 = (bj_o * 1024 + bj_i) % 4
                        e: T.int32 = expert_indices[b, j]
                        # cumsum is indexed expert-major: e * batch_size + b.
                        reverse_indices[cumsum[e * batch_size + b] - 1] = b * 4 + j
                        token_indices[cumsum[e * batch_size + b] - 1] = b

    @T.prim_func(private=True)
    def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        # Computes one value per row, renorm_prob[r, 0], from cumsum_sorted (c),
        # top_p (p) and top_k (k). A column j "passes" when
        #     c[r, j] < p[r]  and  j + 1 < k[r].
        #   * If column 0 does not pass (c[r, 0] >= p[r] or k[r] <= 1):
        #       renorm_prob[r, 0] = c[r, 0].
        #   * Otherwise, for a passing column j:
        #       - if j is the last column, renorm_prob[r, 0] = c[r, j];
        #       - else if column j + 1 does not pass, renorm_prob[r, 0] = c[r, j + 1].
        # Buffers: A -> cumsum_sorted (batch, vocab_size), B -> top_p (batch, 1),
        #          C -> top_k (batch, 1) int32, D -> renorm_prob (batch, 1).
        # NOTE(review): presumably c is non-decreasing along axis 1 so that only
        # one column per row performs the write in the second case — confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
        top_p = T.match_buffer(B, (batch, 1))
        top_k = T.match_buffer(C, (batch, 1), "int32")
        renorm_prob = T.match_buffer(D, (batch, 1))
        # with T.block("root"):
        for ax0, ax1 in T.grid(batch, vocab_size):
            with T.block("T_get_renorm_prob"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0])
                T.writes(renorm_prob[v_ax0, 0])
                # Column 0 already fails the test: every column iteration stores
                # the same value c[r, 0].
                if not (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > 1):
                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
                else:
                    # Only passing columns may write.
                    if cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0]):
                        # Last column: there is no successor to inspect.
                        if v_ax1 + T.int64(1) == vocab_size:
                            renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
                        else:
                            # The successor is the first failing column: take its value.
                            if not (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < T.Cast("int64", top_k[v_ax0, 0])):
                                renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]

    @T.prim_func(private=True)
    def gpu_2d_continuous_cumsum(var_a: T.handle, var_out: T.handle):
        # Row-wise inclusive prefix sum over an (m, n) int32 buffer, computed in
        # 512-element tiles (4 x 32 threads, 4 elements per thread).
        # Phases, in the order they appear below:
        #   1. Each (by, bx) block scans one tile of row `by` into Out and stores
        #      the tile total in Tmp[by, bx].
        #   2. `total_rounds` passes scan the totals held in Tmp; pass i works on
        #      the segment starting at i * ceil(n / 512) and appends its own tile
        #      totals at (i + 1) * ceil(n / 512).
        #   3. The passes are unwound from the deepest level upwards, adding the
        #      preceding tile's total from the next level into each element.
        #   4. Every element of Out in tile bx > 0 gets Tmp[by, bx - 1] added.
        # NOTE(review): no explicit barrier between the shared-buffer steps appears
        # in this text; presumably inserted during lowering — confirm.
        T.func_attr({"tir.is_scheduled": 1})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_a, (m, n), "int32")
        Out = T.match_buffer(var_out, (m, n), "int32")
        # with T.block("root"):
        Tmp = T.alloc_buffer((m, n), "int32")
        ceil_log2: T.int64 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        # One extra pass per 9 bits (512 = 2**9) of ceil(log2(n)).
        total_rounds: T.int64 = ceil_log2 // T.int64(9)
        # ---- Phase 1: per-tile scan of A into Out --------------------------------
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                with T.block(""):
                    T.reads(A[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)])
                    T.writes(Out[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)], Tmp[by, bx])
                    local_buf = T.alloc_buffer((4,), "int32", scope="local")
                    shared_buf = T.alloc_buffer((T.int64(512),), "int32", scope="shared")
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            # First column handled by this thread within row `by`.
                            tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                            # Load 4 consecutive elements, zero-filling past column n.
                            for i in T.vectorized(T.int64(4)):
                                local_buf[i] = T.if_then_else(tx_idx + i < n, T.Cast("int32", A[by, tx_idx + i]), T.Cast("int32", 0))
                            # Inclusive scan of the thread's own 4 elements.
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                local_buf[i] = local_buf[i] + local_buf[i - T.int64(1)]
                            for i in T.vectorized(T.int64(4)):
                                shared_buf[ty * T.int64(128) + tx * T.int64(4) + i] = local_buf[i]
                            # Five doubling steps (offsets 1, 2, 4, 8, 16 in tx): add the
                            # last element of the group owned by thread tx - 2**i in the
                            # same ty row.
                            for i in T.unroll(T.int64(5)):
                                for j in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                    if tx >= T.shift_left(T.int64(1), i):
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                            # Threads with ty == 0 chain the four 128-element segments:
                            # segment i receives the final value of segment i - 1.
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                for j in T.vectorized(T.int64(4)):
                                    if ty == T.int64(0):
                                        idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i * T.int64(128) - T.int64(1)]
                            # Write the scanned tile back, skipping columns >= n.
                            for i in T.vectorized(T.int64(4)):
                                idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i
                                if bx * T.int64(512) + idx < n:
                                    Out[by, bx * T.int64(512) + idx] = shared_buf[idx]
                            # Thread (0, 0) records the tile total. The loop variable is
                            # unused, so all four iterations store the same value.
                            if tx == T.int64(0) and ty == T.int64(0):
                                for i in T.vectorized(T.int64(4)):
                                    Tmp[by, bx] = shared_buf[T.int64(511)]
        # ---- Phase 2: scan the tile totals stored in Tmp, level by level ---------
        for i in range(total_rounds):
            # Element count at this level: ceil(n / 512**(i + 1)).
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    with T.block(""):
                        T.reads(Tmp[by, bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)):bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(512)])
                        T.writes(Tmp[by, T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx):T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + (T.max(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(511), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + T.int64(1) - T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx))])
                        local_buf = T.alloc_buffer((4,), "int32", scope="local")
                        shared_buf = T.alloc_buffer((T.int64(512),), "int32", scope="shared")
                        for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                                # Same tile scan as phase 1, reading level i of Tmp
                                # (offset i * ceil(n / 512)) instead of A.
                                for i_1 in T.vectorized(T.int64(4)):
                                    local_buf[i_1] = T.if_then_else(tx_idx + i_1 < cur_len, T.Cast("int32", Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + tx_idx + i_1]), T.Cast("int32", 0))
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    local_buf[i_1] = local_buf[i_1] + local_buf[i_1 - T.int64(1)]
                                for i_1 in T.vectorized(T.int64(4)):
                                    shared_buf[ty * T.int64(128) + tx * T.int64(4) + i_1] = local_buf[i_1]
                                for i_1 in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i_1):
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i_1) * T.int64(4) + T.int64(4) - T.int64(1)]
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i_1 * T.int64(128) + tx * T.int64(4)
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i_1 * T.int64(128) - T.int64(1)]
                                # Scanned values overwrite level i in place.
                                for i_1 in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i_1
                                    if bx * T.int64(512) + idx < cur_len:
                                        Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx * T.int64(512) + idx] = shared_buf[idx]
                                # Tile total goes to slot bx of level i + 1.
                                if tx == T.int64(0) and ty == T.int64(0):
                                    for i_1 in T.vectorized(T.int64(4)):
                                        Tmp[by, (i + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx] = shared_buf[T.int64(511)]
        # ---- Phase 3: propagate offsets back down through the levels of Tmp ------
        for i in range(total_rounds - T.int64(1)):
            # Levels are visited from total_rounds - 2 down to 0.
            real_idx: T.int64 = total_rounds - T.int64(1) - i - T.int64(1)
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            for i_1 in range(T.int64(4)):
                                # Here the 4 elements of a thread are 32 apart, not adjacent.
                                idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i_1 * T.int64(32) + tx
                                if idx < cur_len:
                                    # Tile 0 has no predecessor, hence the 0 fallback.
                                    Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] = Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] + T.if_then_else(bx > T.int64(0), Tmp[by, (real_idx + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx - T.int64(1)], 0)
        # ---- Phase 4: add the preceding tile's running total to Out --------------
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        for i in range(T.int64(4)):
                            idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i * T.int64(32) + tx
                            if idx < n:
                                Out[by, idx] = Out[by, idx] + T.if_then_else(bx > T.int64(0), Tmp[by, bx - T.int64(1)], 0)

    @T.prim_func(private=True)
    def gpu_2d_continuous_cumsum1(var_a: T.handle, var_out: T.handle):
        # Row-wise inclusive prefix sum over an (m, n) buffer; same tiled layout as
        # gpu_2d_continuous_cumsum, but values are cast to "float32" and the
        # buffers carry no explicit dtype argument.
        # Phases, in the order they appear below:
        #   1. Each (by, bx) block scans one 512-element tile of row `by` into Out
        #      and stores the tile total in Tmp[by, bx].
        #   2. `total_rounds` passes scan the totals held in Tmp; pass i works on
        #      the segment starting at i * ceil(n / 512) and appends its own tile
        #      totals at (i + 1) * ceil(n / 512).
        #   3. The passes are unwound from the deepest level upwards, adding the
        #      preceding tile's total from the next level into each element.
        #   4. Every element of Out in tile bx > 0 gets Tmp[by, bx - 1] added.
        # NOTE(review): no explicit barrier between the shared-buffer steps appears
        # in this text; presumably inserted during lowering — confirm.
        T.func_attr({"tir.is_scheduled": 1})
        m, n = T.int64(), T.int64()
        A = T.match_buffer(var_a, (m, n))
        Out = T.match_buffer(var_out, (m, n))
        # with T.block("root"):
        Tmp = T.alloc_buffer((m, n))
        ceil_log2: T.int64 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        # One extra pass per 9 bits (512 = 2**9) of ceil(log2(n)).
        total_rounds: T.int64 = ceil_log2 // T.int64(9)
        # ---- Phase 1: per-tile scan of A into Out --------------------------------
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                with T.block(""):
                    T.reads(A[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)])
                    T.writes(Out[by, bx * T.int64(512):bx * T.int64(512) + T.int64(512)], Tmp[by, bx])
                    local_buf = T.alloc_buffer((4,), scope="local")
                    shared_buf = T.alloc_buffer((T.int64(512),), scope="shared")
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            # First column handled by this thread within row `by`.
                            tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                            # Load 4 consecutive elements, zero-filling past column n.
                            for i in T.vectorized(T.int64(4)):
                                local_buf[i] = T.if_then_else(tx_idx + i < n, T.Cast("float32", A[by, tx_idx + i]), T.Cast("float32", 0))
                            # Inclusive scan of the thread's own 4 elements.
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                local_buf[i] = local_buf[i] + local_buf[i - T.int64(1)]
                            for i in T.vectorized(T.int64(4)):
                                shared_buf[ty * T.int64(128) + tx * T.int64(4) + i] = local_buf[i]
                            # Five doubling steps (offsets 1, 2, 4, 8, 16 in tx): add the
                            # last element of the group owned by thread tx - 2**i in the
                            # same ty row.
                            for i in T.unroll(T.int64(5)):
                                for j in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                    if tx >= T.shift_left(T.int64(1), i):
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                            # Threads with ty == 0 chain the four 128-element segments:
                            # segment i receives the final value of segment i - 1.
                            for i in T.unroll(T.int64(1), T.int64(4)):
                                for j in T.vectorized(T.int64(4)):
                                    if ty == T.int64(0):
                                        idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                        shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i * T.int64(128) - T.int64(1)]
                            # Write the scanned tile back, skipping columns >= n.
                            for i in T.vectorized(T.int64(4)):
                                idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i
                                if bx * T.int64(512) + idx < n:
                                    Out[by, bx * T.int64(512) + idx] = shared_buf[idx]
                            # Thread (0, 0) records the tile total. The loop variable is
                            # unused, so all four iterations store the same value.
                            if tx == T.int64(0) and ty == T.int64(0):
                                for i in T.vectorized(T.int64(4)):
                                    Tmp[by, bx] = shared_buf[T.int64(511)]
        # ---- Phase 2: scan the tile totals stored in Tmp, level by level ---------
        for i in range(total_rounds):
            # Element count at this level: ceil(n / 512**(i + 1)).
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (i + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    with T.block(""):
                        T.reads(Tmp[by, bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)):bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(512)])
                        T.writes(Tmp[by, T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx):T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + (T.max(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)) + T.int64(511), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx) + T.int64(1) - T.min(bx * T.int64(512) + i * ((n + T.int64(511)) // T.int64(512)), (i + T.int64(1)) * ((n + T.int64(511)) // T.int64(512)) + bx))])
                        local_buf = T.alloc_buffer((4,), scope="local")
                        shared_buf = T.alloc_buffer((T.int64(512),), scope="shared")
                        for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                            for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                tx_idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + tx * T.int64(4)
                                # Same tile scan as phase 1, reading level i of Tmp
                                # (offset i * ceil(n / 512)) instead of A.
                                for i_1 in T.vectorized(T.int64(4)):
                                    local_buf[i_1] = T.if_then_else(tx_idx + i_1 < cur_len, T.Cast("float32", Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + tx_idx + i_1]), T.Cast("float32", 0))
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    local_buf[i_1] = local_buf[i_1] + local_buf[i_1 - T.int64(1)]
                                for i_1 in T.vectorized(T.int64(4)):
                                    shared_buf[ty * T.int64(128) + tx * T.int64(4) + i_1] = local_buf[i_1]
                                for i_1 in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i_1):
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[idx - T.shift_left(T.int64(1), i_1) * T.int64(4) + T.int64(4) - T.int64(1)]
                                for i_1 in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i_1 * T.int64(128) + tx * T.int64(4)
                                            shared_buf[idx + j] = shared_buf[idx + j] + shared_buf[i_1 * T.int64(128) - T.int64(1)]
                                # Scanned values overwrite level i in place.
                                for i_1 in T.vectorized(T.int64(4)):
                                    idx: T.int64 = ty * T.int64(128) + tx * T.int64(4) + i_1
                                    if bx * T.int64(512) + idx < cur_len:
                                        Tmp[by, i * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx * T.int64(512) + idx] = shared_buf[idx]
                                # Tile total goes to slot bx of level i + 1.
                                if tx == T.int64(0) and ty == T.int64(0):
                                    for i_1 in T.vectorized(T.int64(4)):
                                        Tmp[by, (i + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx] = shared_buf[T.int64(511)]
        # ---- Phase 3: propagate offsets back down through the levels of Tmp ------
        for i in range(total_rounds - T.int64(1)):
            # Levels are visited from total_rounds - 2 down to 0.
            real_idx: T.int64 = total_rounds - T.int64(1) - i - T.int64(1)
            cur_len: T.int64 = (n + T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1))) - T.int64(1)) // T.shift_left(T.int64(1), T.int64(9) * (real_idx + T.int64(1)))
            for by in T.thread_binding(m, thread="blockIdx.y"):
                for bx in T.thread_binding((cur_len + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                    for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                        for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                            for i_1 in range(T.int64(4)):
                                # Here the 4 elements of a thread are 32 apart, not adjacent.
                                idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i_1 * T.int64(32) + tx
                                if idx < cur_len:
                                    # Tile 0 has no predecessor, hence the 0.0 fallback.
                                    Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] = Tmp[by, real_idx * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + idx] + T.if_then_else(bx > T.int64(0), Tmp[by, (real_idx + T.int64(1)) * ((n + T.int64(512) - T.int64(1)) // T.int64(512)) + bx - T.int64(1)], T.float32(0.0))
        # ---- Phase 4: add the preceding tile's running total to Out --------------
        for by in T.thread_binding(m, thread="blockIdx.y"):
            for bx in T.thread_binding((n + T.int64(511)) // T.int64(512), thread="blockIdx.x"):
                for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                    for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        for i in range(T.int64(4)):
                            idx: T.int64 = bx * T.int64(512) + ty * T.int64(128) + i * T.int64(32) + tx
                            if idx < n:
                                Out[by, idx] = Out[by, idx] + T.if_then_else(bx > T.int64(0), Tmp[by, bx - T.int64(1)], T.float32(0.0))

    @T.prim_func(private=True)
    def index(var_rms_norm48: T.handle, index: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Copies the last position along axis 1 of the (1, seq_len, 2048) float16
        # input into the (1, 1, 2048) output:
        #     index[0, 0, k] = rms_norm48[0, seq_len - 1, k]
        # Note: the output buffer parameter shares its name with the function.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm48 = T.match_buffer(var_rms_norm48, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for i, _, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("index"):
                v_i, v__, v_k = T.axis.remap("SSS", [i, _, k])
                # The middle output axis has extent 1; the source position is
                # always seq_len - 1 regardless of v__.
                T.reads(rms_norm48[v_i, seq_len - T.int64(1), v_k])
                T.writes(index[v_i, v__, v_k])
                index[v_i, v__, v_k] = rms_norm48[v_i, seq_len - T.int64(1), v_k]

    @T.prim_func
    def merge_state_inplace(v: T.handle, s: T.handle, v_other: T.handle, s_other: T.handle):
        # Merges (V_other, S_other) into (V, S) in place. For each (n, h):
        #     m       = max(S, S_other)
        #     a, b    = exp2(S - m), exp2(S_other - m)
        #     V[n,h,:] = V[n,h,:] * a / (a + b) + V_other[n,h,:] * b / (a + b)
        #     S[n,h]   = log2(a + b) + m
        # V and V_other are float16 (N, H, D); S and S_other are (N, H). The
        # blend is computed in float32 and cast back to float16.
        # Thread layout: blockIdx.x = n; threadIdx.y (16) selects h; threadIdx.x
        # (32) selects a group of 4 consecutive d values.
        # NOTE(review): the fixed extents touch h in [0, 16) and d in [0, 128)
        # only; presumably H == 16 and D == 128 at the call sites — confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        N, H, D = T.int32(is_size_var=True), T.int32(is_size_var=True), T.int32(is_size_var=True)
        V = T.match_buffer(v, (N, H, D), "float16")
        S = T.match_buffer(s, (N, H))
        V_other = T.match_buffer(v_other, (N, H, D), "float16")
        S_other = T.match_buffer(s_other, (N, H))
        # with T.block("root"):
        for bx in T.thread_binding(N, thread="blockIdx.x"):
            for by in T.thread_binding(1, thread="blockIdx.y"):
                for ty in T.thread_binding(16, thread="threadIdx.y"):
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("merge"):
                            T.reads(S[bx, ty + by * 16], S_other[bx, ty + by * 16], V[bx, ty + by * 16, tx * 4:tx * 4 + 4], V_other[bx, ty + by * 16, tx * 4:tx * 4 + 4])
                            T.writes(V[bx, ty + by * 16, tx * 4:tx * 4 + 4], S[bx, ty + by * 16])
                            # Per-thread scratch buffers in "local" scope.
                            s_val = T.alloc_buffer((1,), scope="local")
                            s_other_val = T.alloc_buffer((1,), scope="local")
                            s_max = T.alloc_buffer((1,), scope="local")
                            scale = T.alloc_buffer((1,), scope="local")
                            other_scale = T.alloc_buffer((1,), scope="local")
                            v_vec = T.alloc_buffer((4,), "float16", scope="local")
                            v_other_vec = T.alloc_buffer((4,), "float16", scope="local")
                            s_val[0] = S[bx, ty + by * 16]
                            s_other_val[0] = S_other[bx, ty + by * 16]
                            # Subtract the larger value before exp2 so both exponents are <= 0.
                            s_max[0] = T.max(s_val[0], s_other_val[0])
                            s_val[0] = T.exp2(s_val[0] - s_max[0])
                            s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
                            # Normalised blend weights; they sum to 1.
                            scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
                            other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
                            for vec in T.vectorized(4):
                                v_vec[vec] = V[bx, ty + by * 16, tx * 4 + vec]
                            for vec in T.vectorized(4):
                                v_other_vec[vec] = V_other[bx, ty + by * 16, tx * 4 + vec]
                            # Weighted sum in float32, stored back as float16.
                            for vec in range(4):
                                v_vec[vec] = T.Cast("float16", T.Cast("float32", v_vec[vec]) * scale[0] + T.Cast("float32", v_other_vec[vec]) * other_scale[0])
                            for vec in T.vectorized(4):
                                V[bx, ty + by * 16, tx * 4 + vec] = v_vec[vec]
                            # All 32 tx threads of a given (bx, ty) store the same S value.
                            S[bx, ty + by * 16] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]

    @T.prim_func(private=True)
    def moe_dequantize_gemv(x: T.Buffer((1, 2048), "float16"), w: T.Buffer((60, 2816, 256), "uint32"), scale: T.Buffer((60, 2816, 64), "float16"), indptr: T.Buffer((1, 4), "int32"), o: T.Buffer((4, 2816), "float16")):
        # For each of the 4 slots e, takes weight matrix number indptr[0, e] out of
        # the 60 stored in `w`, dequantizes it to float16 and multiplies it with
        # the single input row:
        #     o[e, i] = sum_j x[0, j] * y[i, j],   i in [0, 2816), j in [0, 2048)
        # Weight layout as decoded below: each uint32 word holds eight 4-bit
        # values (column j lives in word j // 8 at bit offset (j % 8) * 4), and
        # one float16 scale covers 32 consecutive columns.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for expert_id in T.thread_binding(4, thread="blockIdx.y"):
            with T.block("gemv_o"):
                e = T.axis.spatial(4, expert_id)
                T.reads(w[indptr[0, e], 0:2816, 0:256], indptr[0, e], scale[indptr[0, e], 0:2816, 0:64], x[0, 0:2048])
                T.writes(o[e, 0:2816])
                # Dequantized weights of the selected matrix.
                y = T.alloc_buffer((2816, 2048), "float16")
                for i1, i2 in T.grid(2816, 2048):
                    with T.block("dequantize"):
                        i, j = T.axis.remap("SS", [i1, i2])
                        T.reads(w[indptr[0, e], i, j // 8], indptr[0, e], scale[indptr[0, e], i, j // 32])
                        T.writes(y[i, j])
                        # Extract the 4-bit field, shift its range by 7, then scale.
                        y[i, j] = (T.Cast("float16", T.bitwise_and(T.shift_right(w[indptr[0, e], i, j // 8], T.Cast("uint32", j % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * scale[indptr[0, e], i, j // 32]
                for i1, i2 in T.grid(2816, 2048):
                    with T.block("gemv"):
                        # i is spatial, j is the reduction axis.
                        i, j = T.axis.remap("SR", [i1, i2])
                        T.reads(x[0, j], y[i, j])
                        T.writes(o[e, i])
                        with T.init():
                            o[e, i] = T.float16(0.0)
                        # Accumulation is carried out in float16.
                        o[e, i] = o[e, i] + x[0, j] * y[i, j]

    @T.prim_func(private=True)
    def moe_dequantize_gemv1(x: T.Buffer((4, 1408), "float16"), w: T.Buffer((60, 2048, 176), "uint32"), scale: T.Buffer((60, 2048, 44), "float16"), indptr: T.Buffer((1, 4), "int32"), o: T.Buffer((4, 2048), "float16")):
        # Dequantize-then-GEMV over 4 selected weight slices, one input row per slice.
        #
        # For each e in [0, 4) (bound to blockIdx.y), indptr[0, e] picks one of the
        # 60 slices along the first axis of `w` / `scale`. That slice is expanded
        # into a (2048, 1408) float16 matrix `y`, and o[e, i] = sum_j x[e, j] * y[i, j].
        # Unlike a single-row GEMV, each e reads its own row x[e, :].
        #
        # Buffers:
        #   x      (4, 1408)       float16 -- one input row per e
        #   w      (60, 2048, 176) uint32  -- packed weights, eight 4-bit fields per word
        #   scale  (60, 2048, 44)  float16 -- one scale per 32 consecutive columns
        #   indptr (1, 4)          int32   -- first-axis index into w / scale for each e
        #   o      (4, 2048)       float16 -- output, one row per e
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for expert_id in T.thread_binding(4, thread="blockIdx.y"):
            with T.block("gemv_o"):
                e = T.axis.spatial(4, expert_id)
                T.reads(w[indptr[0, e], 0:2048, 0:176], indptr[0, e], scale[indptr[0, e], 0:2048, 0:44], x[e, 0:1408])
                T.writes(o[e, 0:2048])
                # Dequantized copy of the selected weight slice.
                y = T.alloc_buffer((2048, 1408), "float16")
                for i1, i2 in T.grid(2048, 1408):
                    with T.block("dequantize"):
                        i, j = T.axis.remap("SS", [i1, i2])
                        T.reads(w[indptr[0, e], i, j // 8], indptr[0, e], scale[indptr[0, e], i, j // 32])
                        T.writes(y[i, j])
                        # Word j // 8 holds column j; shift right by (j % 8) * 4 bits and mask with 15
                        # to extract the 4-bit field, subtract 7, then multiply by the scale of
                        # column group j // 32.
                        y[i, j] = (T.Cast("float16", T.bitwise_and(T.shift_right(w[indptr[0, e], i, j // 8], T.Cast("uint32", j % 8 * 4)), T.uint32(15))) - T.float16(7.0)) * scale[indptr[0, e], i, j // 32]
                for i1, i2 in T.grid(2048, 1408):
                    with T.block("gemv"):
                        # i is spatial, j is the reduction axis.
                        i, j = T.axis.remap("SR", [i1, i2])
                        T.reads(x[e, j], y[i, j])
                        T.writes(o[e, i])
                        with T.init():
                            o[e, i] = T.float16(0.0)
                        # Accumulation is carried out in float16 (the dtype of o, x and y).
                        o[e, i] = o[e, i] + x[e, j] * y[i, j]

    @T.prim_func
    def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle):
        # Selects one index per output row by scanning a row of `prob` until the
        # accumulated mass reaches a given uniform sample.
        #
        # For each bx in [0, batch_size) (bound to blockIdx.x), the row
        # prob[row_indices[bx, 0], :] is processed in chunks of 512 entries by a
        # 4 x 32 (threadIdx.y x threadIdx.x) thread group, 4 entries per thread.
        # The chosen index is written to token_ids[bx, 0].
        #
        # Buffers:
        #   prob            (n, vocab_size)  -- values are read as float32
        #   uniform_samples (batch_size, 1)  -- values are read as float32
        #   row_indices     (batch_size, 1) int32 -- row of `prob` used for each bx
        #   token_ids       (batch_size, 1) int32 -- output
        #
        # The constant 9.9999999999999995e-07 (about 1e-6) is subtracted from the
        # sample everywhere it is compared against accumulated mass.
        #
        # NOTE(review): no explicit barrier calls appear between the shared-buffer
        # writes and reads in this function; presumably a later lowering pass
        # inserts them -- confirm.
        T.func_attr({"tir.is_scheduled": 1})
        n, vocab_size = T.int64(), T.int64()
        prob = T.match_buffer(var_prob, (n, vocab_size))
        batch_size = T.int64()
        uniform_samples = T.match_buffer(var_uniform_samples, (batch_size, 1))
        row_indices = T.match_buffer(var_row_indices, (batch_size, 1), "int32")
        token_ids = T.match_buffer(var_sampled_token_ids, (batch_size, 1), "int32")
        # with T.block("root"):
        # Scalar scratch in "local" scope: mass accumulated over finished chunks,
        # the index chosen so far, and the chunk counter.
        aggregate = T.alloc_buffer((), scope="local")
        sample_id_local = T.alloc_buffer((), "int32", scope="local")
        step_iter = T.alloc_buffer((), "int32", scope="local")
        for bx in T.thread_binding(batch_size, thread="blockIdx.x"):
            row_idx: T.int32 = row_indices[bx, 0]
            for ty in T.thread_binding(T.int64(4), thread="threadIdx.y"):
                for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                    u: T.float32 = uniform_samples[bx, 0]
                    aggregate[()] = T.Cast("float32", 0)
                    step_iter[()] = 0
                    # Continue while (this is the first chunk, or the accumulated mass is still
                    # below u - eps) and fewer than ceil(vocab_size / 512) chunks have been processed.
                    while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)):
                        with T.block(""):
                            T.reads(step_iter[()], prob[row_idx, T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + T.int64(4)], aggregate[()])
                            T.writes(sample_id_local[()], aggregate[()])
                            prob_gt_threshold = T.alloc_buffer((T.int64(4),), scope="local")
                            cumsum = T.alloc_buffer((T.int64(512),), scope="shared")
                            greater_than_u = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            mask = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            valid = T.alloc_buffer((T.int64(4),), "bool", scope="local")
                            indices = T.alloc_buffer((T.int64(4),), "int32", scope="local")
                            step_aggregate = T.alloc_buffer((), scope="local")
                            # Each thread loads 4 consecutive entries of the current chunk.
                            # Positions at or beyond vocab_size read as 0, non-positive values are
                            # replaced by 0, and valid[v] records in-range entries that are > 0.
                            for v in T.unroll(T.int64(4)):
                                idx: T.int64 = T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v
                                prob_local: T.float32 = T.if_then_else(idx < vocab_size, prob[row_idx, idx], T.Cast("float32", 0))
                                prob_gt_threshold[v] = T.if_then_else(prob_local > T.float32(0.0), prob_local, T.Cast("float32", 0))
                                valid[v] = prob_local > T.float32(0.0) and idx < vocab_size
                            # Chunk total: per-thread sum of its 4 values, then a 7-step pairwise
                            # tree reduction over the 128 per-thread partial sums held in
                            # shared_buf; the total is read back from shared_buf[0].
                            with T.block(""):
                                T.reads(prob_gt_threshold[T.int64(0):T.int64(4)])
                                T.writes(step_aggregate[()])
                                local_sum = T.alloc_buffer((), scope="local")
                                shared_buf = T.alloc_buffer((T.int64(128),), scope="shared")
                                idx: T.int64 = ty * T.int64(32) + tx
                                local_sum[()] = T.Cast("float32", 0)
                                for i in T.unroll(T.int64(4)):
                                    local_sum[()] = local_sum[()] + prob_gt_threshold[i]
                                shared_buf[idx] = local_sum[()]
                                for i in T.unroll(T.int64(7)):
                                    if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                        shared_buf[idx] = shared_buf[idx] + shared_buf[idx + T.shift_left(T.int64(1), i)]
                                step_aggregate[()] = shared_buf[0]
                            # The position inside the chunk is only located when the mass accumulated
                            # through the end of this chunk reaches u - eps.
                            if T.tvm_thread_invariant(aggregate[()] + step_aggregate[()] >= u - T.float32(9.9999999999999995e-07)):
                                # Inclusive running sum over this thread's own 4 values.
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    prob_gt_threshold[i] = prob_gt_threshold[i] + prob_gt_threshold[i - T.int64(1)]
                                for i in T.vectorized(T.int64(4)):
                                    cumsum[ty * T.int64(128) + tx * T.int64(4) + i] = prob_gt_threshold[i]
                                # Five doubling steps (tx offsets 1, 2, 4, 8, 16): a thread with
                                # tx >= 2**i adds the last cumsum entry of thread tx - 2**i (same ty)
                                # to each of its own 4 entries.
                                for i in T.unroll(T.int64(5)):
                                    for j in T.vectorized(T.int64(4)):
                                        idx: T.int64 = ty * T.int64(128) + tx * T.int64(4)
                                        if tx >= T.shift_left(T.int64(1), i):
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[idx - T.shift_left(T.int64(1), i) * T.int64(4) + T.int64(4) - T.int64(1)]
                                # Threads with ty == 0 then add the last entry of 128-entry segment
                                # i - 1 (index i * 128 - 1) to segment i, for i = 1, 2, 3 in order.
                                for i in T.unroll(T.int64(1), T.int64(4)):
                                    for j in T.vectorized(T.int64(4)):
                                        if ty == T.int64(0):
                                            idx: T.int64 = i * T.int64(128) + tx * T.int64(4)
                                            cumsum[idx + j] = cumsum[idx + j] + cumsum[i * T.int64(128) - T.int64(1)]
                                # Flag entries whose cumulative mass, including earlier chunks,
                                # has reached u - eps.
                                for v in T.unroll(T.int64(4)):
                                    greater_than_u[v] = cumsum[ty * T.int64(128) + tx * T.int64(4) + v] + aggregate[()] >= u - T.float32(9.9999999999999995e-07)
                                # mask marks positions where the flag differs from the preceding
                                # entry. A thread's first entry is compared with the previous
                                # thread's last flag (exchanged through shared_buf); the very first
                                # thread uses its own first flag directly.
                                with T.block(""):
                                    T.reads(greater_than_u[T.int64(0):T.int64(4)])
                                    T.writes(mask[T.int64(0):T.int64(4)])
                                    shared_buf = T.alloc_buffer((T.int64(128),), "bool", scope="shared")
                                    tx_idx: T.int64 = ty * T.int64(32) + tx
                                    shared_buf[tx_idx] = greater_than_u[T.int64(3)]
                                    mask[0] = T.if_then_else(tx_idx != T.int64(0), T.Cast("int8", greater_than_u[0]) != T.Cast("int8", shared_buf[tx_idx - T.int64(1)]), greater_than_u[0])
                                    for i in T.unroll(T.int64(1), T.int64(4)):
                                        mask[i] = T.Cast("int8", greater_than_u[i]) != T.Cast("int8", greater_than_u[i - T.int64(1)])
                                # Drop masked positions that are not valid and record each entry's
                                # index within the whole row.
                                for v in T.unroll(T.int64(4)):
                                    mask[v] = mask[v] and valid[v]
                                    indices[v] = T.Cast("int32", T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * T.int64(4) + v)
                                # Smallest masked index across all 128 threads (7-step pairwise tree
                                # reduction with min). The starting value vocab_size - 1 is the
                                # result when no entry is masked.
                                with T.block(""):
                                    T.reads(mask[T.int64(0):T.int64(4)], indices[T.int64(0):T.int64(4)])
                                    T.writes(sample_id_local[()])
                                    local_sum = T.alloc_buffer((), "int32", scope="local")
                                    shared_buf = T.alloc_buffer((T.int64(128),), "int32", scope="shared")
                                    idx: T.int64 = ty * T.int64(32) + tx
                                    local_sum[()] = T.Cast("int32", vocab_size - T.int64(1))
                                    for i in T.unroll(T.int64(4)):
                                        if mask[i]:
                                            local_sum[()] = T.min(local_sum[()], indices[i])
                                    shared_buf[idx] = local_sum[()]
                                    for i in T.unroll(T.int64(7)):
                                        if idx % T.shift_left(T.int64(1), i + T.int64(1)) == T.int64(0):
                                            shared_buf[idx] = T.min(shared_buf[idx], shared_buf[idx + T.shift_left(T.int64(1), i)])
                                    sample_id_local[()] = shared_buf[0]
                            # Fold this chunk's total into the running mass.
                            aggregate[()] = aggregate[()] + step_aggregate[()]
                        step_iter[()] = step_iter[()] + 1
                    # Only thread (tx, ty) == (0, 0) stores the result for this bx.
                    if tx == T.int64(0) and ty == T.int64(0):
                        token_ids[bx, 0] = sample_id_local[()]

    @T.prim_func(private=True)
    def reshape(var_add288: T.handle, var_T_reshape: T.handle):
        # Copies add288 (batch_size, 1, 6144) into T_reshape (batch_size, 1, 48, 128),
        # both float16, splitting the last axis into 48 x 128. batch_size is a
        # symbolic int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        add288 = T.match_buffer(var_add288, (batch_size, T.int64(1), T.int64(6144)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(48), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(batch_size, T.int64(1), T.int64(48), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # Source column is the flat offset v_ax2 * 128 + v_ax3. Within the loop extents
                # that offset is at most 6143 and v_ax1 is 0, so the source row is v_ax0.
                T.reads(add288[((v_ax2 * T.int64(128) + v_ax3) // T.int64(6144) + v_ax0 + v_ax1) % batch_size, T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(6144)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = add288[((v_ax2 * T.int64(128) + v_ax3) // T.int64(6144) + v_ax0 + v_ax1) % batch_size, T.int64(0), (v_ax2 * T.int64(128) + v_ax3) % T.int64(6144)]

    @T.prim_func(private=True)
    def reshape1(var_reshape624: T.handle, var_T_reshape: T.handle):
        # Copies reshape624 (batch_size, 1, 48, 128) into T_reshape (batch_size, 48, 128),
        # both float16, dropping the size-1 second axis. batch_size is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape624 = T.match_buffer(var_reshape624, (batch_size, T.int64(1), T.int64(48), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(48), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(48), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax2 < 128 and v_ax1 < 48, so both quotient terms are 0
                # and the source element is reshape624[v_ax0, 0, v_ax1, v_ax2].
                T.reads(reshape624[((v_ax2 // T.int64(128) + v_ax1) // T.int64(48) + v_ax0) % batch_size, T.int64(0), (v_ax2 // T.int64(128) + v_ax1) % T.int64(48), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape624[((v_ax2 // T.int64(128) + v_ax1) // T.int64(48) + v_ax0) % batch_size, T.int64(0), (v_ax2 // T.int64(128) + v_ax1) % T.int64(48), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape10(var_reshape408: T.handle, var_T_reshape: T.handle):
        # Copies reshape408 (1, seq_len, 48, 128) into T_reshape (seq_len, 48, 128),
        # both float16, dropping the size-1 leading axis. seq_len is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape408 = T.match_buffer(var_reshape408, (T.int64(1), seq_len, T.int64(48), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(48), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(seq_len, T.int64(48), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax2 < 128 and v_ax1 < 48, so both quotient terms are 0
                # and the source element is reshape408[0, v_ax0, v_ax1, v_ax2].
                T.reads(reshape408[T.int64(0), ((v_ax2 // T.int64(128) + v_ax1) // T.int64(48) + v_ax0) % seq_len, (v_ax2 // T.int64(128) + v_ax1) % T.int64(48), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape408[T.int64(0), ((v_ax2 // T.int64(128) + v_ax1) // T.int64(48) + v_ax0) % seq_len, (v_ax2 // T.int64(128) + v_ax1) % T.int64(48), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape11(var_lv485: T.handle, var_T_reshape: T.handle):
        # Copies lv485 (seq_len, 16, 128) into T_reshape (1, seq_len, 16, 128),
        # both float16, adding a size-1 leading axis. seq_len is a symbolic int64
        # bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        lv485 = T.match_buffer(var_lv485, (seq_len, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(16), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), seq_len, T.int64(16), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # Within the loop extents v_ax0 is 0, v_ax3 < 128 and v_ax2 < 16, so the
                # source element is lv485[v_ax1, v_ax2, v_ax3].
                T.reads(lv485[((v_ax3 // T.int64(128) + v_ax2) // T.int64(16) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv485[((v_ax3 // T.int64(128) + v_ax2) // T.int64(16) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape12(var_reshape410: T.handle, var_T_reshape: T.handle):
        # Copies reshape410 (1, seq_len, 16, 128) into T_reshape (1, seq_len, 2048),
        # both float16, merging the last two axes (16 * 128 = 2048). seq_len is a
        # symbolic int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        reshape410 = T.match_buffer(var_reshape410, (T.int64(1), seq_len, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # v_ax2 is split into (v_ax2 // 128, v_ax2 % 128) for the last two source axes.
                # Within the loop extents v_ax0 is 0 and v_ax2 < 2048, so the source row is v_ax1.
                T.reads(reshape410[T.int64(0), (v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape410[T.int64(0), (v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape13(var_rms_norm99: T.handle, var_T_reshape: T.handle):
        # Copies rms_norm99 (1, seq_len, 2048) into T_reshape (seq_len, 2048),
        # both float16, dropping the size-1 leading axis. seq_len is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm99 = T.match_buffer(var_rms_norm99, (T.int64(1), seq_len, T.int64(2048)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1 in T.grid(seq_len, T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # Within the loop extents v_ax1 < 2048, so the source element is
                # rms_norm99[0, v_ax0, v_ax1].
                T.reads(rms_norm99[T.int64(0), (v_ax1 // T.int64(2048) + v_ax0) % seq_len, v_ax1 % T.int64(2048)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = rms_norm99[T.int64(0), (v_ax1 // T.int64(2048) + v_ax0) % seq_len, v_ax1 % T.int64(2048)]

    @T.prim_func(private=True)
    def reshape14(var_add194: T.handle, var_T_reshape: T.handle):
        # Copies add194 (seq_len, 2048) into T_reshape (1, seq_len, 2048),
        # both float16, adding a size-1 leading axis. seq_len is a symbolic int64
        # bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        add194 = T.match_buffer(var_add194, (seq_len, T.int64(2048)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax0 is 0 and v_ax2 < 2048, so the source element is
                # add194[v_ax1, v_ax2].
                T.reads(add194[(v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = add194[(v_ax2 // T.int64(2048) + v_ax0 * seq_len + v_ax1) % seq_len, v_ax2 % T.int64(2048)]

    @T.prim_func(private=True)
    def reshape2(var_lv774: T.handle, var_T_reshape: T.handle):
        # Copies lv774 (batch_size, 16, 128) into T_reshape (batch_size, 1, 16, 128),
        # both float16, inserting a size-1 second axis. batch_size is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv774 = T.match_buffer(var_lv774, (batch_size, T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(16), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(batch_size, T.int64(1), T.int64(16), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # Within the loop extents v_ax1 is 0, v_ax3 < 128 and v_ax2 < 16, so the
                # source element is lv774[v_ax0, v_ax2, v_ax3].
                T.reads(lv774[((v_ax3 // T.int64(128) + v_ax2) // T.int64(16) + v_ax0 + v_ax1) % batch_size, (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = lv774[((v_ax3 // T.int64(128) + v_ax2) // T.int64(16) + v_ax0 + v_ax1) % batch_size, (v_ax3 // T.int64(128) + v_ax2) % T.int64(16), v_ax3 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape21(lv296: T.Buffer((T.int64(4), T.int64(2048)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(4), T.int64(2048)), "float16")):
        # Copies lv296 (4, 2048) into T_reshape (1, 4, 2048), both float16, adding a
        # size-1 leading axis. All shapes are static.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(4), T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax2 < 2048 and v_ax1 < 4, so the source element is
                # lv296[v_ax1, v_ax2].
                T.reads(lv296[(v_ax2 // T.int64(2048) + v_ax1) % T.int64(4), v_ax2 % T.int64(2048)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = lv296[(v_ax2 // T.int64(2048) + v_ax1) % T.int64(4), v_ax2 % T.int64(2048)]

    @T.prim_func(private=True)
    def reshape22(add98: T.Buffer((T.int64(1), T.int64(2048)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Copies add98 (1, 2048) into T_reshape (1, 1, 2048), both float16, adding a
        # size-1 axis. All shapes are static.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax2 < 2048, so the source element is add98[0, v_ax2].
                T.reads(add98[T.int64(0), v_ax2 % T.int64(2048)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = add98[T.int64(0), v_ax2 % T.int64(2048)]

    @T.prim_func(private=True)
    def reshape3(var_reshape626: T.handle, var_T_reshape: T.handle):
        # Copies reshape626 (batch_size, 1, 16, 128) into T_reshape (batch_size, 1, 2048),
        # both float16, merging the last two axes (16 * 128 = 2048). batch_size is a
        # symbolic int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape626 = T.match_buffer(var_reshape626, (batch_size, T.int64(1), T.int64(16), T.int64(128)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # v_ax2 is split into (v_ax2 // 128, v_ax2 % 128) for the last two source axes.
                # Within the loop extents v_ax1 is 0 and v_ax2 < 2048, so the source row is v_ax0.
                T.reads(reshape626[(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % batch_size, T.int64(0), v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = reshape626[(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % batch_size, T.int64(0), v_ax2 % T.int64(2048) // T.int64(128), v_ax2 % T.int64(128)]

    @T.prim_func(private=True)
    def reshape4(var_rms_norm148: T.handle, var_T_reshape: T.handle):
        # Copies rms_norm148 (batch_size, 1, 2048) into T_reshape (batch_size, 2048),
        # both float16, dropping the size-1 middle axis. batch_size is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        rms_norm148 = T.match_buffer(var_rms_norm148, (batch_size, T.int64(1), T.int64(2048)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1 in T.grid(batch_size, T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # Within the loop extents v_ax1 < 2048, so the source element is
                # rms_norm148[v_ax0, 0, v_ax1].
                T.reads(rms_norm148[(v_ax1 // T.int64(2048) + v_ax0) % batch_size, T.int64(0), v_ax1 % T.int64(2048)])
                T.writes(T_reshape[v_ax0, v_ax1])
                T_reshape[v_ax0, v_ax1] = rms_norm148[(v_ax1 // T.int64(2048) + v_ax0) % batch_size, T.int64(0), v_ax1 % T.int64(2048)]

    @T.prim_func(private=True)
    def reshape5(var_permute_dims486: T.handle, var_T_reshape: T.handle):
        # Flattens permute_dims486 (60, batch_size) into the 1-D buffer T_reshape of
        # length batch_size * 60, both int32. batch_size is a symbolic int64 bound
        # from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        permute_dims486 = T.match_buffer(var_permute_dims486, (T.int64(60), batch_size), "int32")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size * T.int64(60),), "int32")
        # with T.block("root"):
        for ax0 in range(batch_size * T.int64(60)):
            with T.block("T_reshape"):
                v_ax0 = T.axis.spatial(batch_size * T.int64(60), ax0)
                # Flat position v_ax0 maps to source row v_ax0 // batch_size and source column
                # v_ax0 % batch_size.
                T.reads(permute_dims486[v_ax0 // batch_size % T.int64(60), v_ax0 % batch_size])
                T.writes(T_reshape[v_ax0])
                T_reshape[v_ax0] = permute_dims486[v_ax0 // batch_size % T.int64(60), v_ax0 % batch_size]

    @T.prim_func(private=True)
    def reshape6(var_top4_softmax_072: T.handle, var_T_reshape: T.handle):
        # Copies top4_softmax_072 (batch_size, 4) into T_reshape (batch_size, 4, 1),
        # both float16, adding a size-1 trailing axis. batch_size is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        top4_softmax_072 = T.match_buffer(var_top4_softmax_072, (batch_size, T.int64(4)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(4), T.int64(1)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(4), T.int64(1)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax2 is 0 and v_ax1 < 4, so the source element is
                # top4_softmax_072[v_ax0, v_ax1].
                T.reads(top4_softmax_072[((v_ax1 + v_ax2) // T.int64(4) + v_ax0) % batch_size, (v_ax1 + v_ax2) % T.int64(4)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = top4_softmax_072[((v_ax1 + v_ax2) // T.int64(4) + v_ax0) % batch_size, (v_ax1 + v_ax2) % T.int64(4)]

    @T.prim_func(private=True)
    def reshape7(var_lv782: T.handle, var_T_reshape: T.handle):
        # Copies lv782 (batch_size * 4, 2048) into T_reshape (batch_size, 4, 2048),
        # both float16, splitting the first axis into batch_size x 4. batch_size is
        # a symbolic int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        lv782 = T.match_buffer(var_lv782, (batch_size * T.int64(4), T.int64(2048)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(4), T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(4), T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax2 < 2048, so the source row is v_ax0 * 4 + v_ax1
                # and the source column is v_ax2.
                T.reads(lv782[(v_ax0 * T.int64(4) + v_ax2 // T.int64(2048) + v_ax1) % (batch_size * T.int64(4)), v_ax2 % T.int64(2048)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = lv782[(v_ax0 * T.int64(4) + v_ax2 // T.int64(2048) + v_ax1) % (batch_size * T.int64(4)), v_ax2 % T.int64(2048)]

    @T.prim_func(private=True)
    def reshape8(var_add290: T.handle, var_T_reshape: T.handle):
        # Copies add290 (batch_size, 2048) into T_reshape (batch_size, 1, 2048),
        # both float16, inserting a size-1 middle axis. batch_size is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        add290 = T.match_buffer(var_add290, (batch_size, T.int64(2048)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # Within the loop extents v_ax1 is 0 and v_ax2 < 2048, so the source element is
                # add290[v_ax0, v_ax2].
                T.reads(add290[(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % batch_size, v_ax2 % T.int64(2048)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
                T_reshape[v_ax0, v_ax1, v_ax2] = add290[(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % batch_size, v_ax2 % T.int64(2048)]

    @T.prim_func(private=True)
    def reshape9(var_add192: T.handle, var_T_reshape: T.handle):
        # Copies add192 (1, seq_len, 6144) into T_reshape (1, seq_len, 48, 128),
        # both float16, splitting the last axis into 48 x 128. seq_len is a symbolic
        # int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        add192 = T.match_buffer(var_add192, (T.int64(1), seq_len, T.int64(6144)), "float16")
        T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), seq_len, T.int64(48), T.int64(128)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), seq_len, T.int64(48), T.int64(128)):
            with T.block("T_reshape"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                # Source column is the flat offset v_ax2 * 128 + v_ax3. Within the loop extents
                # that offset is at most 6143 and v_ax0 is 0, so the source row is v_ax1.
                T.reads(add192[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(6144) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax2 * T.int64(128) + v_ax3) % T.int64(6144)])
                T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
                T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = add192[T.int64(0), ((v_ax2 * T.int64(128) + v_ax3) // T.int64(6144) + v_ax0 * seq_len + v_ax1) % seq_len, (v_ax2 * T.int64(128) + v_ax3) % T.int64(6144)]

    @T.prim_func(private=True)
    def rms_norm(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight4: T.Buffer((T.int64(2048),), "float16"), var_T_cast: T.handle):
        # RMS normalization over the last axis (length 2048), computed in float32 and
        # stored as float16.
        #
        # For every (b, 0) row of input_embeds:
        #   r = rsqrt(sum_k x[b, 0, k]**2 * (1 / 2048) + 1e-6)
        #   T_cast[b, 0, k] = float16(r * x[b, 0, k] * weight[k])
        # where x and weight are the float32 casts of the float16 inputs.
        #
        # Buffers:
        #   input_embeds (batch_size, 1, 2048) float16 -- input
        #   model_layers_0_input_layernorm_weight4 (2048,) float16 -- per-column weight
        #   T_cast       (batch_size, 1, 2048) float16 -- output
        # batch_size is a symbolic int64 bound from the buffer shapes.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (batch_size, T.int64(1), T.int64(2048)), "float16")
        T_cast = T.match_buffer(var_T_cast, (batch_size, T.int64(1), T.int64(2048)), "float16")
        # with T.block("root"):
        # Intermediate buffers; no dtype is given, and they are filled with float32 values below.
        T_cast_1 = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)))
        T_multiply = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)))
        T_multiply_red = T.alloc_buffer((batch_size, T.int64(1)))
        rsqrt = T.alloc_buffer((batch_size, T.int64(1)))
        T_cast_2 = T.alloc_buffer((T.int64(2048),))
        T_rms_norm = T.alloc_buffer((batch_size, T.int64(1), T.int64(2048)))
        # Stage 1: widen the input to float32.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embeds[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2])
        # Stage 2: element-wise square.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # Stage 3: sum of squares over the last axis (k2 is the reduction axis).
        for ax0, ax1, k2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # Stage 4: reciprocal square root of the mean square plus epsilon.
        for ax0, ax1 in T.grid(batch_size, T.int64(1)):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                # 0.00048828125 == 1 / 2048 turns the sum into a mean; 9.9999999999999995e-07
                # (about 1e-6) is added before taking rsqrt.
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07))
        # Stage 5: widen the weight vector to float32.
        for ax0 in range(T.int64(2048)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(model_layers_0_input_layernorm_weight4[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", model_layers_0_input_layernorm_weight4[v_ax0])
        # Stage 6: scale each input element by its row's rsqrt value and the column weight.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # Stage 7: narrow the result back to float16 into the output buffer.
        for ax0, ax1, ax2 in T.grid(batch_size, T.int64(1), T.int64(2048)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func(private=True)
    def rms_norm1(var_input_embeds: T.handle, model_layers_0_input_layernorm_weight3: T.Buffer((T.int64(2048),), "float16"), var_T_cast: T.handle):
        # RMS normalization along the last axis (2048 elements) of a (1, seq_len, 2048)
        # float16 input, scaled element-wise by a (2048,) float16 weight vector.
        #
        #   input_embeds (var_input_embeds): (1, seq_len, 2048) float16, read only.
        #   model_layers_0_input_layernorm_weight3: (2048,) float16 scale, read only.
        #   T_cast (var_T_cast): (1, seq_len, 2048) float16, the output.
        #
        # The arithmetic is staged through intermediate buffers allocated without an
        # explicit dtype (TVMScript's default, float32); only the final stage casts the
        # result back to float16. `seq_len` is a symbolic int64 bound by the match_buffer
        # calls below.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        input_embeds = T.match_buffer(var_input_embeds, (T.int64(1), seq_len, T.int64(2048)), "float16")
        T_cast = T.match_buffer(var_T_cast, (T.int64(1), seq_len, T.int64(2048)), "float16")
        # with T.block("root"):
        # NOTE: the local buffer `T_cast_1` / `T_cast_2` names do not line up with the
        # block names "T_cast_1" / "T_cast_2" used further down: block "T_cast" writes
        # buffer T_cast_1, block "T_cast_1" writes buffer T_cast_2, and block "T_cast_2"
        # writes the output buffer T_cast.
        T_cast_1 = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)))
        T_multiply = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)))
        T_multiply_red = T.alloc_buffer((T.int64(1), seq_len))
        rsqrt = T.alloc_buffer((T.int64(1), seq_len))
        T_cast_2 = T.alloc_buffer((T.int64(2048),))
        T_rms_norm = T.alloc_buffer((T.int64(1), seq_len, T.int64(2048)))
        # Stage 1: widen the float16 input to float32.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embeds[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embeds[v_ax0, v_ax1, v_ax2])
        # Stage 2: element-wise square.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # Stage 3: sum of squares over the last axis (k2 is the reduction axis, "R").
        for ax0, ax1, k2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # Stage 4: reciprocal square root of (sum * 1/2048 + ~1e-6).
        # 0.00048828125 == 1/2048, so the product is the mean of the squares; the small
        # additive constant keeps the rsqrt argument strictly positive.
        for ax0, ax1 in T.grid(T.int64(1), seq_len):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07))
        # Stage 5: widen the float16 weight vector to float32.
        for ax0 in range(T.int64(2048)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(model_layers_0_input_layernorm_weight3[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", model_layers_0_input_layernorm_weight3[v_ax0])
        # Stage 6: normalized value = rsqrt(row) * x * weight, still in float32.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # Stage 7: narrow the float32 result to float16 in the output buffer.
        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(2048)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func(private=True)
    def rms_norm2(input_embed: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16"), model_layers_0_input_layernorm_weight2: T.Buffer((T.int64(2048),), "float16"), T_cast: T.Buffer((T.int64(1), T.int64(1), T.int64(2048)), "float16")):
        # Fully static-shape RMS normalization of a single (1, 1, 2048) float16 vector,
        # scaled element-wise by a (2048,) float16 weight vector.
        #
        #   input_embed: (1, 1, 2048) float16, read only.
        #   model_layers_0_input_layernorm_weight2: (2048,) float16 scale, read only.
        #   T_cast: (1, 1, 2048) float16, the output.
        #
        # The stages and constants are the same as in `rms_norm1` above; only the shapes
        # differ (no symbolic dimension here). Intermediate buffers are allocated without
        # an explicit dtype (TVMScript's default, float32).
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # NOTE: block "T_cast" writes buffer T_cast_1, block "T_cast_1" writes buffer
        # T_cast_2, and block "T_cast_2" writes the output buffer T_cast.
        T_cast_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)))
        T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)))
        T_multiply_red = T.alloc_buffer((T.int64(1), T.int64(1)))
        rsqrt = T.alloc_buffer((T.int64(1), T.int64(1)))
        T_cast_2 = T.alloc_buffer((T.int64(2048),))
        T_rms_norm = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2048)))
        # Stage 1: widen the float16 input to float32.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_cast"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(input_embed[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
                T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32", input_embed[v_ax0, v_ax1, v_ax2])
        # Stage 2: element-wise square.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
                T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
        # Stage 3: sum of squares over the 2048-element axis (k2 is the reduction axis).
        for ax0, ax1, k2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_multiply_red"):
                v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
                T.writes(T_multiply_red[v_ax0, v_ax1])
                with T.init():
                    T_multiply_red[v_ax0, v_ax1] = T.float32(0.0)
                T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
        # Stage 4: rsqrt(sum * 1/2048 + ~1e-6); 0.00048828125 == 1/2048 turns the sum
        # into a mean, and the additive constant keeps the argument strictly positive.
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
            with T.block("rsqrt"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(T_multiply_red[v_ax0, v_ax1])
                T.writes(rsqrt[v_ax0, v_ax1])
                rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1] * T.float32(0.00048828125) + T.float32(9.9999999999999995e-07))
        # Stage 5: widen the float16 weight vector to float32.
        for ax0 in range(T.int64(2048)):
            with T.block("T_cast_1"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(model_layers_0_input_layernorm_weight2[v_ax0])
                T.writes(T_cast_2[v_ax0])
                T_cast_2[v_ax0] = T.Cast("float32", model_layers_0_input_layernorm_weight2[v_ax0])
        # Stage 6: normalized value = rsqrt * x * weight, still in float32.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_rms_norm"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax2])
                T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
        # Stage 7: narrow the float32 result to float16 in the output buffer.
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2048)):
            with T.block("T_cast_2"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
                T.writes(T_cast[v_ax0, v_ax1, v_ax2])
                T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16", T_rms_norm[v_ax0, v_ax1, v_ax2])

    @T.prim_func
    def sampler_take_probs_tir(var_unsorted_probs: T.handle, var_sorted_indices: T.handle, var_sample_indices: T.handle, var_sampling_results: T.handle, var_top_prob_offsets: T.handle, var_sampled_values: T.handle, var_top_prob_probs: T.handle, var_top_prob_indices: T.handle):
        # Gathers probabilities out of `unsorted_probs` for two independent index sets
        # in one fused loop of num_positions + num_samples iterations.
        #
        # Inputs:
        #   unsorted_probs   (batch_size, vocab_size), default dtype (float32)
        #   sorted_indices   (batch_size, vocab_size) int32
        #   sample_indices   (num_samples,) int32   - row of unsorted_probs per sample
        #   sampling_results (num_samples,) int32   - column of unsorted_probs per sample
        #   top_prob_offsets (num_positions,) int32 - flat offsets into the
        #                    (batch_size, vocab_size) grid, split with // and % below
        # Outputs:
        #   sampled_values   (num_samples,)
        #   top_prob_probs   (num_positions,)
        #   top_prob_indices (num_positions,) int32
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32})})
        batch_size, vocab_size = T.int32(is_size_var=True), T.int32(is_size_var=True)
        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size))
        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), "int32")
        num_samples = T.int32(is_size_var=True)
        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), "int32")
        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), "int32")
        num_positions = T.int32(is_size_var=True)
        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), "int32")
        sampled_values = T.match_buffer(var_sampled_values, (num_samples,))
        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,))
        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), "int32")
        # with T.block("root"):
        for i in range(num_positions + num_samples):
            with T.block("block"):
                vi = T.axis.spatial(num_positions + num_samples, i)
                # The declared read/write regions below are the union over BOTH branches
                # of the if/else (hence the min/max bounding ranges and the
                # `vi - num_positions` accesses listed even for the first branch).
                T.reads(top_prob_offsets[vi], sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], unsorted_probs[T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]):T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + (T.max(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions]) + 1 - T.min(top_prob_offsets[vi] // vocab_size, sample_indices[vi - num_positions])), T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]):T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + (T.max(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]) + 1 - T.min(sorted_indices[top_prob_offsets[vi] // vocab_size, top_prob_offsets[vi] % vocab_size], sampling_results[vi - num_positions]))], sample_indices[vi - num_positions], sampling_results[vi - num_positions])
                T.writes(top_prob_indices[vi], top_prob_probs[vi], sampled_values[vi - num_positions])
                if vi < num_positions:
                    # First num_positions iterations: decode the flat offset into
                    # (row, col), look up the token index stored at that sorted position,
                    # then fetch that token's probability from the unsorted table.
                    row: T.int32 = top_prob_offsets[vi] // vocab_size
                    col: T.int32 = top_prob_offsets[vi] % vocab_size
                    top_prob_indices[vi] = sorted_indices[row, col]
                    top_prob_probs[vi] = unsorted_probs[row, sorted_indices[row, col]]
                else:
                    # Remaining num_samples iterations: vj re-bases the index to 0 and
                    # the (row, column) pair comes from sample_indices / sampling_results.
                    vj: T.int32 = vi - num_positions
                    sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]

    @T.prim_func(private=True)
    def scatter_output(var_x: T.handle, var_indices: T.handle, var_out: T.handle):
        # Row scatter: copies row i of `x` into row indices[i] of `out`.
        #
        #   x       (indices_len, 2048) float16, source rows
        #   indices (indices_len,) int32, destination row for each source row
        #   out     (indices_len, 2048) float16, destination
        #
        # Only rows of `out` named in `indices` are written here.
        # NOTE(review): if `indices` held duplicates, several iterations would write the
        # same `out` row; presumably callers pass a permutation - confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        indices_len = T.int64()
        x = T.match_buffer(var_x, (indices_len, 2048), "float16")
        indices = T.match_buffer(var_indices, (indices_len,), "int32")
        out = T.match_buffer(var_out, (indices_len, 2048), "float16")
        # with T.block("root"):
        for i, j in T.grid(indices_len, 2048):
            with T.block("scatter"):
                vi = T.axis.spatial(indices_len, i)
                vj = T.axis.spatial(2048, j)
                T.reads(x[vi, vj], indices[vi])
                # The write location is data-dependent (indirect through `indices`).
                T.writes(out[indices[vi], vj])
                out[indices[vi], vj] = x[vi, vj]

    @T.prim_func
    def scatter_probs(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):
        # Row scatter: copies row b of `src` into row indices[b] of `dst`.
        #
        #   src     (batch_size, n), default dtype (float32), source rows
        #   indices (batch_size,) int32, destination row for each source row
        #   dst     (m, n), default dtype (float32), destination; its row count `m` is
        #           an independent size variable, so it may differ from batch_size
        #
        # Rows of `dst` not named in `indices` are not written by this function.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, n = T.int32(is_size_var=True), T.int32(is_size_var=True)
        src = T.match_buffer(var_src, (batch_size, n))
        indices = T.match_buffer(var_indices, (batch_size,), "int32")
        m = T.int32(is_size_var=True)
        dst = T.match_buffer(var_dst, (m, n))
        # with T.block("root"):
        for b, j in T.grid(batch_size, n):
            with T.block("scatter_2d"):
                vb, vj = T.axis.remap("SS", [b, j])
                T.reads(src[vb, vj], indices[vb])
                # The write location is data-dependent (indirect through `indices`).
                T.writes(dst[indices[vb], vj])
                dst[indices[vb], vj] = src[vb, vj]

    @T.prim_func
    def softmax_with_chunked_sum(var_A: T.handle, var_temperature: T.handle, var_chunked_sum: T.handle, var_chunked_max: T.handle, var_softmax: T.handle):
        # Second pass of a chunked softmax: combines per-chunk partial results
        # (`chunked_max`, `chunked_sum`) into a per-row max and sum, then writes the
        # final `softmax` values for one 4096-element chunk per blockIdx.x.
        #
        #   A           (batch_size, vocab_size)  input values
        #   temperature (batch_size,)             per-row temperature
        #   chunked_sum (batch_size, num_chunks)  per-chunk partial sums
        #   chunked_max (batch_size, num_chunks)  per-chunk partial maxima
        #   softmax     (batch_size, vocab_size)  output
        # All buffers use the default dtype (float32).
        #
        # Rows with temperature <= ~1e-5 take a separate path in both the sum and the
        # output stage (an equality test against the row max instead of exp()).
        # NOTE(review): on the temperature > 1e-5 path `chunked_sum` is used inside
        # exp(chunked_sum + chunked_max - max), so it is presumably a log-domain partial
        # sum produced by the first-pass kernel - confirm against that producer.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        A = T.match_buffer(var_A, (batch_size, vocab_size))
        temperature = T.match_buffer(var_temperature, (batch_size,))
        num_chunks = T.int64(is_size_var=True)
        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks))
        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks))
        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size))
        # with T.block("root"):
        # Per-row reduction results, declared in "shared" scope.
        temp_max_shared = T.alloc_buffer((batch_size,), scope="shared")
        temp_sum_shared = T.alloc_buffer((batch_size,), scope="shared")
        # One blockIdx.x per (row, chunk) pair; the fused index is split back into
        # row = fused // num_chunks and chunk = fused % num_chunks inside the blocks.
        for l0_l1_fused in T.thread_binding(batch_size * num_chunks, thread="blockIdx.x"):
            # Stage 1 ("max"): row max over all chunk maxima. The chunk axis is strided
            # across 32 threadIdx.x lanes; T.where masks the tail when num_chunks is
            # not a multiple of 32.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("max"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(chunked_max[v0, v1])
                        T.writes(temp_max_shared[v0])
                        with T.init():
                            # Identity for max: the most negative finite float32.
                            temp_max_shared[v0] = T.float32(-340282346638528859811704183484516925440.0)
                        temp_max_shared[v0] = T.max(temp_max_shared[v0], chunked_max[v0, v1])
            # Stage 2 ("sum_exp"): row sum, with each chunk's contribution re-based to
            # the row max computed in stage 1.
            for ax0_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                for ax0_0 in T.serial((num_chunks + T.int64(31)) // T.int64(32), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                    with T.block("sum_exp"):
                        v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                        v1 = T.axis.reduce(num_chunks, ax0_0 * T.int64(32) + ax0_1)
                        T.where(ax0_0 * T.int64(32) + ax0_1 < num_chunks)
                        T.reads(temperature[v0], chunked_sum[v0, v1], chunked_max[v0, v1], temp_max_shared[v0])
                        T.writes(temp_sum_shared[v0])
                        with T.init():
                            temp_sum_shared[v0] = T.float32(0.0)
                        # temperature > 1e-5: add exp(chunked_sum + chunked_max - row_max).
                        # otherwise: add chunked_sum only for chunks whose max equals
                        # the row max (the bool comparison is cast to 0.0 / 1.0).
                        temp_sum_shared[v0] = temp_sum_shared[v0] + T.Select(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max_shared[v0]), T.Cast("float32", chunked_max[v0, v1] == temp_max_shared[v0]) * chunked_sum[v0, v1])
            # Stage 3 ("log_pad"): write the outputs of this block's chunk. The 4096
            # elements are covered by 4 serial steps x 32 threadIdx.y x 32 threadIdx.x.
            for l2_0 in T.serial(T.int64(4), annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1}):
                for l2_1 in T.thread_binding(T.int64(32), thread="threadIdx.y"):
                    for l2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                        with T.block("log_pad"):
                            v0 = T.axis.spatial(batch_size, l0_l1_fused % (num_chunks * batch_size) // num_chunks)
                            v1 = T.axis.spatial(num_chunks, l0_l1_fused % num_chunks)
                            v2 = T.axis.spatial(T.int64(4096), l2_0 * T.int64(1024) + l2_1 * T.int64(32) + l2_2)
                            T.reads(temperature[v0], A[v0, v1 * T.int64(4096) + v2], temp_sum_shared[v0], temp_max_shared[v0])
                            T.writes(softmax[v0, v1 * T.int64(4096) + v2])
                            # Skip the padding past vocab_size in the last chunk.
                            if v1 * T.int64(4096) + v2 < vocab_size:
                                # temperature > 1e-5: exp(A / T - (log(row_sum) + row_max)).
                                # otherwise: (A == row_max) / row_sum, i.e. the mass is
                                # split evenly among the entries equal to the row max.
                                softmax[v0, v1 * T.int64(4096) + v2] = T.if_then_else(temperature[v0] > T.float32(1.0000000000000001e-05), T.exp(A[v0, v1 * T.int64(4096) + v2] / temperature[v0] - (T.log(temp_sum_shared[v0]) + temp_max_shared[v0])), T.Cast("float32", A[v0, v1 * T.int64(4096) + v2] == temp_max_shared[v0]) / temp_sum_shared[v0])

    @T.prim_func(private=True)
    def take(var_reshape628: T.handle, var_get_indices_148: T.handle, var_T_take: T.handle):
        # Row gather: T_take[i, :] = reshape628[get_indices_148[i], :].
        #
        #   reshape628      (batch_size, 2048) float16, source rows
        #   get_indices_148 (batch_size * 4,) int32, source row for each output row
        #   T_take          (batch_size * 4, 2048) float16, output
        #
        # The output has four times as many rows as the source, so a source row may be
        # selected more than once. No bounds check on the index values is visible here.
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()
        reshape628 = T.match_buffer(var_reshape628, (batch_size, T.int64(2048)), "float16")
        get_indices_148 = T.match_buffer(var_get_indices_148, (batch_size * T.int64(4),), "int32")
        T_take = T.match_buffer(var_T_take, (batch_size * T.int64(4), T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1 in T.grid(batch_size * T.int64(4), T.int64(2048)):
            with T.block("T_take"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # The read location is data-dependent (indirect through the index buffer).
                T.reads(reshape628[get_indices_148[v_ax0], v_ax1], get_indices_148[v_ax0])
                T.writes(T_take[v_ax0, v_ax1])
                T_take[v_ax0, v_ax1] = reshape628[get_indices_148[v_ax0], v_ax1]

    @T.prim_func(private=True)
    def take1(var_rms_norm146: T.handle, var_logit_positions: T.handle, var_T_take: T.handle):
        # Gather along axis 1: T_take[0, i, :] = rms_norm146[0, logit_positions[i], :].
        #
        #   rms_norm146     (1, seq_len, 2048) float16, source
        #   logit_positions (batch_size,) int32, position along the seq_len axis to pick
        #   T_take          (1, batch_size, 2048) float16, output
        #
        # `seq_len` and `batch_size` are independent symbolic sizes. No bounds check on
        # the position values is visible here.
        T.func_attr({"tir.noalias": T.bool(True)})
        seq_len = T.int64()
        rms_norm146 = T.match_buffer(var_rms_norm146, (T.int64(1), seq_len, T.int64(2048)), "float16")
        batch_size = T.int64()
        logit_positions = T.match_buffer(var_logit_positions, (batch_size,), "int32")
        T_take = T.match_buffer(var_T_take, (T.int64(1), batch_size, T.int64(2048)), "float16")
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), batch_size, T.int64(2048)):
            with T.block("T_take"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                # The read location is data-dependent (indirect through logit_positions).
                T.reads(rms_norm146[v_ax0, logit_positions[v_ax1], v_ax2], logit_positions[v_ax1])
                T.writes(T_take[v_ax0, v_ax1, v_ax2])
                T_take[v_ax0, v_ax1, v_ax2] = rms_norm146[v_ax0, logit_positions[v_ax1], v_ax2]

    @T.prim_func(private=True)
    def take_sorted_probs(var_probs: T.handle, var_lv1: T.handle, var_take_sorted_probs: T.handle):
        # Per-row gather: take_sorted_probs[i, j] = probs[i, lv1[i, j]].
        #
        #   probs             (batch_size, vocab_size), default dtype (float32)
        #   lv1               (batch_size, vocab_size) int32, column index to read
        #   take_sorted_probs (batch_size_1, vocab_size_1), default dtype (float32)
        #
        # NOTE(review): the output shape is bound to its own symbolic variables
        # (batch_size_1, vocab_size_1) and the loop runs over those extents; nothing in
        # this function ties them to the inputs' (batch_size, vocab_size). Presumably
        # the caller always passes matching shapes - confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        batch_size, vocab_size = T.int64(), T.int64()
        probs = T.match_buffer(var_probs, (batch_size, vocab_size))
        lv1 = T.match_buffer(var_lv1, (batch_size, vocab_size), "int32")
        batch_size_1, vocab_size_1 = T.int64(is_size_var=True), T.int64(is_size_var=True)
        take_sorted_probs = T.match_buffer(var_take_sorted_probs, (batch_size_1, vocab_size_1))
        # with T.block("root"):
        for i, j in T.grid(batch_size_1, vocab_size_1):
            with T.block("take_sorted_probs"):
                v_i, v_j = T.axis.remap("SS", [i, j])
                # The column read from `probs` is data-dependent (indirect through lv1).
                T.reads(probs[v_i, lv1[v_i, v_j]], lv1[v_i, v_j])
                T.writes(take_sorted_probs[v_i, v_j])
                take_sorted_probs[v_i, v_j] = probs[v_i, lv1[v_i, v_j]]

    @T.prim_func
    def tir_kv_cache_debug_get_kv(var_pages: T.handle, var_position_map: T.handle, var_k_data: T.handle, var_v_data: T.handle, layer_id: T.int64):
        # Copies the entries of a paged buffer selected by `position_map` into slice
        # `layer_id` of two dense output buffers.
        #
        #   pages        (num_pages, 2, 16, page_size, 128) float16; index 0 on axis 1
        #                feeds k_data, index 1 feeds v_data
        #   position_map (seqlen,) int32, flat position of each output row inside pages
        #   k_data       (24, seqlen, 16, 128) float16, output
        #   v_data       (24, seqlen, 16, 128) float16, output
        #   layer_id     int64, index on the leading (size 24) axis of k_data / v_data
        #
        # Only the `layer_id` slice of the outputs is written. No range check on
        # `layer_id` or on the position values is visible here.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        num_pages, page_size = T.int64(), T.int64(is_size_var=True)
        pages = T.match_buffer(var_pages, (num_pages, 2, 16, page_size, 128), "float16", offset_factor=1)
        seqlen = T.int64(is_size_var=True)
        position_map = T.match_buffer(var_position_map, (seqlen,), "int32", offset_factor=1)
        k_data = T.match_buffer(var_k_data, (24, seqlen, 16, 128), "float16")
        v_data = T.match_buffer(var_v_data, (24, seqlen, 16, 128), "float16")
        # with T.block("root"):
        for p, h, d in T.grid(seqlen, 16, 128):
            with T.block("copy0"):
                vp, vh, vd = T.axis.remap("SSS", [p, h, d])
                T.reads(position_map[vp], pages[T.Cast("int64", position_map[vp]) // page_size, 0:2, vh, T.Cast("int64", position_map[vp]) % page_size, vd])
                T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
                # The int32 position is widened to int64 to match page_size, then split
                # into page number (// page_size) and in-page slot (% page_size).
                position: T.int32 = position_map[vp]
                k_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 0, vh, T.Cast("int64", position) % page_size, vd]
                v_data[layer_id, vp, vh, vd] = pages[T.Cast("int64", position) // page_size, 1, vh, T.Cast("int64", position) % page_size, vd]

    @T.prim_func
    def tir_kv_cache_transpose_append(var_pages: T.handle, var_k_data: T.handle, var_v_data: T.handle, var_position_map: T.handle):
        # Writes token-major k_data / v_data into the paged buffer at the slots given
        # by `position_map`.
        #
        #   pages        (num_pages, 2, 16, 16, 128) float16, destination; index 0 on
        #                axis 1 receives k_data, index 1 receives v_data
        #   k_data       (ntoken, 16, 128) float16, source
        #   v_data       (ntoken, 16, 128) float16, source
        #   position_map (ntoken,) int32, flat destination position per token; tokens
        #                whose entry is -1 are skipped
        #
        # The page size is hard-coded to 16 (axis 3 of `pages`, and the // 16, % 16
        # below). The source layout is (token, h, f) while the destination is
        # (page, k/v, h, slot, f), i.e. the token and h axes swap order on the way in.
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.noalias": T.bool(True)})
        num_pages = T.int64()
        pages = T.match_buffer(var_pages, (num_pages, 2, 16, 16, 128), "float16", offset_factor=1)
        ntoken = T.int64(is_size_var=True)
        k_data = T.match_buffer(var_k_data, (ntoken, 16, 128), "float16")
        v_data = T.match_buffer(var_v_data, (ntoken, 16, 128), "float16")
        position_map = T.match_buffer(var_position_map, (ntoken,), "int32", offset_factor=1)
        # with T.block("root"):
        for global_pos, h, f in T.grid(ntoken, 16, 128):
            # -1 marks a token that must not be written.
            if position_map[global_pos] != -1:
                with T.block("k_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 0, vh, position_map[vgpos] % 16, vf])
                    # page = position // 16, slot within the page = position % 16.
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 0, vh, position % 16, vf] = k_data[vgpos, vh, vf]
                with T.block("v_transpose_append"):
                    vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
                    T.reads(position_map[vgpos], v_data[vgpos, vh, vf])
                    T.writes(pages[position_map[vgpos] // 16, 1, vh, position_map[vgpos] % 16, vf])
                    # Same destination slot as the K write above, on index 1 of axis 1.
                    position: T.int32 = position_map[vgpos]
                    pages[position // 16, 1, vh, position % 16, vf] = v_data[vgpos, vh, vf]

    @T.prim_func(private=True)
    def top4_softmax(var_x: T.handle, var_out: T.handle, var_out_index: T.handle):
        # For every row of `x`, selects the four largest of its 60 values, in descending
        # order, together with their column indices.
        #
        #   x         (batch_size, 60) float16, input
        #   out       (batch_size, 4) float16, the four largest values per row
        #   out_index (batch_size, 4) int32, their column indices
        #
        # One row is handled per thread (1024 threads per block). The selection is a
        # single pass with a four-slot insertion cascade held in "local" scope buffers.
        # NOTE(review): despite the name, no softmax is computed in this function - the
        # selected values are written to `out` unchanged. Presumably normalization
        # happens in a separate step - confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        batch_size = T.int64()
        x = T.match_buffer(var_x, (batch_size, 60), "float16")
        out = T.match_buffer(var_out, (batch_size, 4), "float16")
        out_index = T.match_buffer(var_out_index, (batch_size, 4), "int32")
        # with T.block("root"):
        # Running top-4 values / indices; slot 0 holds the largest, slot 3 the smallest.
        local_top_k = T.alloc_buffer((4,), "float16", scope="local")
        local_top_k_index = T.alloc_buffer((4,), "int32", scope="local")
        # ceil(batch_size / 1024) blocks of 1024 threads; T.where masks the tail.
        for io in T.thread_binding((batch_size + T.int64(1023)) // T.int64(1024), thread="blockIdx.x"):
            for ii in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("top_k"):
                    vi = T.axis.spatial(batch_size, io * T.int64(1024) + T.Cast("int64", ii))
                    T.where(io * T.int64(1024) + T.Cast("int64", ii) < batch_size)
                    T.reads(x[vi, 0:60], local_top_k[0:4], local_top_k_index[0:4])
                    T.writes(local_top_k[0:4], local_top_k_index[0:4], out[vi, 0:4], out_index[vi, 0:4])
                    with T.block("init"):
                        T.reads()
                        T.writes(local_top_k[0:4], local_top_k_index[0:4])
                        # Seed every slot with -65504 (the most negative finite float16)
                        # and placeholder indices 0..3. Because the comparisons below are
                        # strict, an input equal to -65504 never replaces a seed.
                        local_top_k[0] = T.float16(-65504.0)
                        local_top_k[1] = T.float16(-65504.0)
                        local_top_k[2] = T.float16(-65504.0)
                        local_top_k[3] = T.float16(-65504.0)
                        local_top_k_index[0] = 0
                        local_top_k_index[1] = 1
                        local_top_k_index[2] = 2
                        local_top_k_index[3] = 3
                    for k in range(60):
                        with T.block("update"):
                            vk = T.axis.spatial(60, k)
                            T.reads(x[vi, vk], local_top_k[0:4], local_top_k_index[0:3])
                            T.writes(local_top_k[0:4], local_top_k_index[0:4])
                            # Insertion cascade: find the first slot the candidate beats,
                            # shift that slot and everything below it down by one (the
                            # old slot 3 is dropped), then store the candidate. Strict
                            # `>` means that on ties the earlier column keeps its slot.
                            if x[vi, vk] > local_top_k[0]:
                                # New maximum: shift slots 0..2 down to 1..3.
                                local_top_k[3] = local_top_k[2]
                                local_top_k_index[3] = local_top_k_index[2]
                                local_top_k[2] = local_top_k[1]
                                local_top_k_index[2] = local_top_k_index[1]
                                local_top_k[1] = local_top_k[0]
                                local_top_k_index[1] = local_top_k_index[0]
                                local_top_k[0] = x[vi, vk]
                                local_top_k_index[0] = vk
                            else:
                                if x[vi, vk] > local_top_k[1]:
                                    # New second: shift slots 1..2 down to 2..3.
                                    local_top_k[3] = local_top_k[2]
                                    local_top_k_index[3] = local_top_k_index[2]
                                    local_top_k[2] = local_top_k[1]
                                    local_top_k_index[2] = local_top_k_index[1]
                                    local_top_k[1] = x[vi, vk]
                                    local_top_k_index[1] = vk
                                else:
                                    if x[vi, vk] > local_top_k[2]:
                                        # New third: shift slot 2 down to 3.
                                        local_top_k[3] = local_top_k[2]
                                        local_top_k_index[3] = local_top_k_index[2]
                                        local_top_k[2] = x[vi, vk]
                                        local_top_k_index[2] = vk
                                    else:
                                        if x[vi, vk] > local_top_k[3]:
                                            # New fourth: overwrite the last slot.
                                            local_top_k[3] = x[vi, vk]
                                            local_top_k_index[3] = vk
                    # Copy the four slots to the output row (loop marked for unrolling).
                    for j in T.unroll(4):
                        with T.block("output"):
                            vj = T.axis.spatial(4, j)
                            T.reads(local_top_k[vj], local_top_k_index[vj])
                            T.writes(out[vi, vj], out_index[vi, vj])
                            out[vi, vj] = local_top_k[vj]
                            out_index[vi, vj] = local_top_k_index[vj]

    @T.prim_func(private=True)
    def top_p_pivot_cutoff(var_prob: T.handle, var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        prob = T.match_buffer(var_prob, (B, N))
        top_p_arr = T.match_buffer(var_top_p_arr, (B,))
        init_pivots = T.match_buffer(var_init_pivots, (B, 3))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        # with T.block("root"):
        pivot = T.alloc_buffer((3,), scope="local")
        top_p = T.alloc_buffer((1,), scope="local")
        L = T.alloc_buffer((1,), scope="shared")
        R_1 = T.alloc_buffer((1,), scope="shared")
        L_local = T.alloc_buffer((1,), scope="local")
        R_local = T.alloc_buffer((1,), scope="local")
        q = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((3,), scope="local")
        lmin_broadcast = T.alloc_buffer((1,), scope="shared")
        lmin_broadcast_local = T.alloc_buffer((1,), scope="local")
        lmin = T.alloc_buffer((3,), scope="local")
        cmin = T.alloc_buffer((3,), "int32", scope="local")
        total_sum = T.alloc_buffer((1,), scope="local")
        it = T.alloc_buffer((1,), "int32", scope="local")
        es_local = T.alloc_buffer((1,), "bool", scope="local")
        es = T.alloc_buffer((1,), "bool", scope="shared")
        find_pivot_local = T.alloc_buffer((1,), "bool", scope="local")
        find_pivot = T.alloc_buffer((1,), "bool", scope="shared")
        total_sum_reduce = T.alloc_buffer((1,), scope="local")
        lsum_reduce = T.alloc_buffer((1,), scope="local")
        lmin_reduce = T.alloc_buffer((1,), scope="local")
        cmin_reduce = T.alloc_buffer((1,), "int32", scope="local")
        for _bx in T.thread_binding(B, thread="blockIdx.x"):
            for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                with T.block("CTA"):
                    b, tx = T.axis.remap("SS", [_bx, _tx])
                    T.reads(top_p_arr[b], top_p[0], L[0], R_1[0], init_pivots[b, 0:3], L_local[0], R_local[0], find_pivot_local[0], it[0], es_local[0], prob[b, it[0] * 1024 + tx], total_sum[0], q[0], pivot[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lsum[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], lmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], cmin[T.min(0, it[0]):T.min(0, it[0]) + (T.max(2, it[0]) + 1 - T.min(0, it[0]))], total_sum_reduce[0], es[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], lsum_reduce[0], cmin_reduce[0], find_pivot[0])
                    T.writes(top_p[0], L[0], R_1[0], find_pivot[0], L_local[0], R_local[0], pivot[0:3], find_pivot_local[0], final_lsum[b], final_pivot[b], lsum[0:3], lmin[0:3], cmin[0:3], total_sum[0], it[0], es_local[0], q[0], total_sum_reduce[0], es[0], lsum_reduce[0], lmin_reduce[0], lmin_broadcast[0], lmin_broadcast_local[0], cmin_reduce[0])
                    top_p[0] = top_p_arr[b]
                    if tx == 0:
                        L[0] = T.float32(1.0) - top_p[0]
                        R_1[0] = T.float32(9.9999999999999995e-08)
                        find_pivot[0] = T.bool(False)
                    T.tvm_storage_sync("shared")
                    L_local[0] = L[0]
                    R_local[0] = R_1[0]
                    for i in T.unroll(3):
                        pivot[i] = init_pivots[b, i]
                    find_pivot_local[0] = T.bool(False)
                    if L_local[0] - R_local[0] <= T.float32(9.9999999999999995e-08):
                        if tx == 0:
                            final_lsum[b] = T.float32(1.0)
                            final_pivot[b] = T.float32(0.0)
                        find_pivot_local[0] = T.bool(True)
                    while T.tvm_thread_invariant(L_local[0] - R_local[0] > T.float32(9.9999999999999995e-08) and not find_pivot_local[0]):
                        T.tvm_storage_sync("shared")
                        for pidx in T.unroll(3):
                            lsum[pidx] = T.float32(0.0)
                            lmin[pidx] = T.float32(340282346638528859811704183484516925440.0)
                            cmin[pidx] = 0
                        total_sum[0] = T.float32(0.0)
                        it[0] = 0
                        es_local[0] = T.bool(False)
                        while it[0] < (N + 1024 - 1) // 1024 and not es_local[0]:
                            q[0] = T.if_then_else(it[0] * 1024 + tx < N, prob[b, it[0] * 1024 + tx], T.float32(0.0))
                            total_sum[0] = total_sum[0] + q[0]
                            for pidx in T.unroll(3):
                                if q[0] >= pivot[pidx]:
                                    lsum[pidx] = lsum[pidx] + q[0]
                                    if lmin[pidx] > q[0]:
                                        lmin[pidx] = q[0]
                                        cmin[pidx] = 1
                                    else:
                                        if lmin[pidx] == q[0]:
                                            cmin[pidx] = cmin[pidx] + 1
                            it[0] = it[0] + 1
                            if it[0] % 32 == 0:
                                with T.block("block_cross_thread"):
                                    T.reads(total_sum[0])
                                    T.writes(total_sum_reduce[0])
                                    T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                    T.tvm_thread_allreduce(T.uint32(1), total_sum[0], T.bool(True), total_sum_reduce[0], tx)
                                if tx == 0:
                                    es[0] = T.float32(1.0) - total_sum_reduce[0] < pivot[2]
                                T.tvm_storage_sync("shared")
                                es_local[0] = es[0]
                        T.tvm_storage_sync("shared")
                        for pidx in range(3):
                            with T.block("block_cross_thread"):
                                T.reads(lsum[pidx])
                                T.writes(lsum_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], T.bool(True), lsum_reduce[0], tx)
                            with T.block("block_cross_thread"):
                                T.reads(lmin[pidx])
                                T.writes(lmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0.0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], T.bool(True), lmin_reduce[0], tx)
                            if tx == 0:
                                lmin_broadcast[0] = lmin_reduce[0]
                            T.tvm_storage_sync("shared")
                            lmin_broadcast_local[0] = lmin_broadcast[0]
                            if lmin[pidx] > lmin_broadcast_local[0]:
                                cmin[pidx] = 0
                            if tx == 0:
                                lsum[pidx] = lsum_reduce[0]
                                lmin[pidx] = lmin_reduce[0]
                            with T.block("block_cross_thread"):
                                T.reads(cmin[pidx])
                                T.writes(cmin_reduce[0])
                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [0]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
                                T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], T.bool(True), cmin_reduce[0], tx)
                            if tx == 0:
                                cmin[pidx] = cmin_reduce[0]
                        T.tvm_storage_sync("shared")
                        if tx == 0:
                            it[0] = 0
                            while it[0] < 3 and not find_pivot_local[0]:
                                if lsum[it[0]] >= top_p[0] and top_p[0] > lsum[it[0]] - T.Cast("float32", cmin[it[0]]) * lmin[it[0]]:
                                    find_pivot[0] = T.bool(True)
                                    find_pivot_local[0] = T.bool(True)
                                    final_pivot[b] = pivot[it[0]]
                                    final_lsum[b] = lsum[it[0]]
                                else:
                                    if lsum[it[0]] - lmin[it[0]] * T.Cast("float32", cmin[it[0]]) >= top_p[0]:
                                        R_1[0] = pivot[it[0]]
                                        final_lsum[b] = lsum[it[0]]
                                    else:
                                        if lsum[it[0]] < top_p[0]:
                                            L[0] = pivot[it[0]]
                                it[0] = it[0] + 1
                        T.tvm_storage_sync("shared")
                        L_local[0] = L[0]
                        R_local[0] = R_1[0]
                        find_pivot_local[0] = find_pivot[0]
                        for pidx in T.unroll(3):
                            pivot[pidx] = L[0] - T.Cast("float32", pidx + 1) * (L_local[0] - R_local[0]) / T.float32(4.0)
                    if tx == 0:
                        if not find_pivot_local[0]:
                            final_pivot[b] = R_local[0]
                            if R_local[0] == T.float32(9.9999999999999995e-08):
                                final_lsum[b] = lsum[2]

    @T.prim_func(private=True)
    def top_p_renorm_after_cutoff(var_prob: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, var_renorm_prob: T.handle):
        # Row-wise threshold-and-rescale kernel. For every row `by` and column `j < N`:
        #   renorm_prob[by, j] = prob[by, j] / final_lsum[by]   if prob[by, j] >= final_pivot[by]
        #                        0.0                            otherwise
        #
        # Parameters (each handle is bound to a buffer below):
        #   var_prob:        (B, N) buffer, only read by this kernel.
        #   var_final_pivot: (B,) buffer holding the per-row threshold.
        #   var_final_lsum:  (B,) buffer holding the per-row divisor.
        #   var_renorm_prob: (B, N) buffer, only written by this kernel.
        #
        # NOTE(review): final_pivot / final_lsum are presumably the outputs of the
        # top-p pivot search that runs before this kernel - confirm against the caller.
        # NOTE(review): nothing here guards against final_lsum[by] == 0; the division
        # below relies on the producer of final_lsum to rule that out - confirm.
        T.func_attr({"target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
        # B (rows) and N (columns) are symbolic sizes resolved from the buffer shapes.
        B, N = T.int32(is_size_var=True), T.int32(is_size_var=True)
        # No dtype is given to T.match_buffer, so these buffers use its default dtype.
        prob = T.match_buffer(var_prob, (B, N))
        final_pivot = T.match_buffer(var_final_pivot, (B,))
        final_lsum = T.match_buffer(var_final_lsum, (B,))
        renorm_prob = T.match_buffer(var_renorm_prob, (B, N))
        # with T.block("root"):
        # One-element scope="local" copies of the row's threshold and divisor, so the
        # inner loop does not re-read final_pivot / final_lsum on every iteration.
        pivot = T.alloc_buffer((1,), scope="local")
        lsum = T.alloc_buffer((1,), scope="local")
        # Launch shape: blockIdx.y selects the row; blockIdx.x (511 // B + 1 blocks)
        # together with threadIdx.x (1024 threads) covers the columns of that row.
        for _by in T.thread_binding(B, thread="blockIdx.y"):
            for _bx in T.thread_binding(511 // B + 1, thread="blockIdx.x"):
                for _tx in T.thread_binding(1024, thread="threadIdx.x"):
                    with T.block("CTA"):
                        by, bx, tx = T.axis.remap("SSS", [_by, _bx, _tx])
                        # Declared access regions: a contiguous range from this thread's
                        # first column (bx * 1024 + tx) up to the column it touches in
                        # the last loop iteration.
                        T.reads(final_pivot[by], final_lsum[by], prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)], pivot[0], lsum[0])
                        T.writes(pivot[0], lsum[0], renorm_prob[by, bx * 1024 + tx:bx * 1024 + tx + (((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024) - 1) * (511 // B + 1) * 1024 + 1)])
                        pivot[0] = final_pivot[by]
                        lsum[0] = final_lsum[by]
                        # Trip count is ceil(N / ((511 // B + 1) * 1024)), i.e. the number
                        # of passes the blockIdx.x * threadIdx.x threads need to cover N
                        # columns.
                        for i in range((511 // B * 1024 + N + 1023) // (511 // B * 1024 + 1024)):
                            # Column handled in pass i. (512 + B - 1) // B equals
                            # 511 // B + 1 for B >= 1, so each pass advances by the full
                            # blockIdx.x * 1024 width. The `< N` test masks the tail where
                            # the strided index runs past the end of the row.
                            if i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx < N:
                                # Keep and rescale entries at or above the threshold; zero the rest.
                                renorm_prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] = T.if_then_else(prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] >= pivot[0], prob[by, i * ((512 + B - 1) // B) * 1024 + bx * 1024 + tx] / lsum[0], T.float32(0.0))

    @T.prim_func
    def tree_attn_paged_kv(_0: T.int32, var_q: T.handle, var_q_indptr: T.handle, var_pages: T.handle, var_page_indptr: T.handle, var_page_values: T.handle, var_length_info: T.handle, var_k_rope_pos_offset: T.handle, var_q_rope_position: T.handle, var_output: T.handle, var_lse: T.handle, rotary_mode: T.int32, rope_scale: T.float32, rope_theta: T.float32, attn_score_scaling_factor: T.float32, tree_order_indptr_handle: T.handle, tree_order_handle: T.handle):
        T.func_attr({"target": T.target({"arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "mcpu": "znver3", "mtriple": "x86_64-pc-linux-gnu", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "libs": ["thrust"], "max_num_threads": 1024, "max_shared_memory_per_block": 49152, "max_threads_per_block": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": 1})
        total_len = T.int32(is_size_var=True)
        q = T.match_buffer(var_q, (total_len, 16, 128), "float16")
        batch_size = T.int32(is_size_var=True)
        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", offset_factor=1)
        max_num_pages = T.int32(is_size_var=True)
        pages = T.match_buffer(var_pages, (max_num_pages, 2, 16, 16, 128), "float16")
        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", offset_factor=1)
        nnz_pages = T.int32(is_size_var=True)
        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", offset_factor=1)
        length_info = T.match_buffer(var_length_info, (batch_size,), "int32", offset_factor=1)
        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", offset_factor=1)
        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", offset_factor=1)
        output = T.match_buffer(var_output, (total_len, 16, 128), "float16")
        lse = T.match_buffer(var_lse, (total_len, 16))
        tree_order_indptr = T.match_buffer(tree_order_indptr_handle, (batch_size + 1,), "int32", offset_factor=1)
        total_tree_order_len = T.int32(is_size_var=True)
        tree_order = T.match_buffer(tree_order_handle, (total_tree_order_len, 2), "int32", offset_factor=1)
        # with T.block("root"):
        assert rotary_mode == 0, "Inline rotary mode is not supported in tree attention."
        for lbx in T.thread_binding(16, thread="blockIdx.x"):
            for lby in T.thread_binding(16, thread="blockIdx.y"):
                for lty in T.thread_binding(4, thread="threadIdx.y"):
                    for ltx in T.thread_binding(32, thread="threadIdx.x"):
                        with T.block("attn"):
                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
                            T.reads()
                            T.writes()
                            tile_id = T.alloc_buffer((1,), "int32", scope="local")
                            batch_idx = T.alloc_buffer((1,), "int32", scope="local")
                            batch_tiles = T.alloc_buffer((1,), "int32", scope="local")
                            batch_rows = T.alloc_buffer((1,), "int32", scope="local")
                            iterator = T.alloc_buffer((1,), "int32", scope="local")
                            kv_chunk_len = T.alloc_buffer((1,), "int32", scope="local")
                            Q_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            K_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            V_smem = T.alloc_buffer((32, 128), "float16", scope="shared")
                            S_smem = T.alloc_buffer((32, 32), scope="shared")
                            S_local = T.alloc_buffer((32, 32), scope="local")
                            O_local = T.alloc_buffer((32, 128), scope="local")
                            m_smem = T.alloc_buffer((32,), scope="shared")
                            m_prev_smem = T.alloc_buffer((32,), scope="shared")
                            d_smem = T.alloc_buffer((32,), scope="shared")
                            m_new = T.alloc_buffer((1,), scope="local")
                            m_prev = T.alloc_buffer((1,), scope="local")
                            d_new = T.alloc_buffer((1,), scope="local")
                            tile_id[0] = bx
                            batch_idx[0] = 0
                            batch_rows[0] = q_indptr[1] - q_indptr[0]
                            batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
                                    tile_id[0] = tile_id[0] - batch_tiles[0]
                                    batch_idx[0] = batch_idx[0] + 1
                                    if batch_idx[0] < batch_size:
                                        b_idx: T.int32 = batch_idx[0]
                                        batch_rows[0] = q_indptr[b_idx + 1] - q_indptr[b_idx]
                                        batch_tiles[0] = (batch_rows[0] + 32 - 1) // 32
                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                    b_idx: T.int32 = batch_idx[0]
                                    LH_start: T.int32 = tile_id[0] * 32
                                    q_indptr_val: T.int32 = q_indptr[b_idx]
                                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
                                    kv_chunk_len[0] = T.if_then_else(cur_page_indptr_begin != cur_page_indptr_end, (cur_page_indptr_end - cur_page_indptr_begin - 1) * 16 + length_info[b_idx], 0)
                                    T.tvm_storage_sync("shared")
                                    for i in range(1):
                                        row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                        if row < 32:
                                            m_smem[row] = T.float32(-50000.0)
                                            d_smem[row] = T.float32(1.0)
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_init"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads()
                                                    T.writes(O_local[i, j])
                                                    O_local[i, j] = T.float32(0.0)
                                    T.tvm_storage_sync("shared")
                                    for li_lj_fused_0 in range(8):
                                        for li_lj_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_lj_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_lj_fused_3 in T.vectorized(4):
                                                    with T.block("Q_load"):
                                                        i = T.axis.spatial(32, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) // 128)
                                                        j = T.axis.spatial(128, (li_lj_fused_0 * 512 + li_lj_fused_1 * 128 + li_lj_fused_2 * 4 + li_lj_fused_3) % 128)
                                                        T.reads()
                                                        T.writes()
                                                        cur_L: T.int32 = q_indptr_val + (LH_start + i)
                                                        cur_H_qo: T.int32 = by
                                                        if cur_L < q_indptr[b_idx + 1]:
                                                            freq = T.float32()
                                                            Q_smem[i, j] = T.if_then_else(rotary_mode == 1, T.Let(T.Cast("float16", T.cos(freq) * T.Cast("float32", q[cur_L, cur_H_qo, j]) + T.sin(freq) * T.Cast("float32", T.if_then_else(j < 64, q[cur_L, cur_H_qo, j + 64] * T.float16(-1.0), q[cur_L, cur_H_qo, j - 64]))), where={freq: T.Cast("float32", q_rope_position[cur_L]) * rope_scale / T.pow(rope_theta, T.Cast("float32", j * 2 % 128) / T.float32(128.0))}), q[cur_L, cur_H_qo, j])
                                                        else:
                                                            Q_smem[i, j] = T.float16(0.0)
                                    T.tvm_storage_sync("shared")
                                    for iterator_1 in range((kv_chunk_len[0] + 31) // 32):
                                        L_kv_start: T.int32 = iterator_1 * 32
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("K_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                K_smem[i, j] = pages[page_no, 0, by, page_offset, j]
                                                            else:
                                                                K_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        for lz_ly_fused_0 in range(8):
                                            for lz_ly_fused_1 in T.thread_binding(4, thread="threadIdx.y"):
                                                for lz_ly_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lz_ly_fused_3 in T.vectorized(4):
                                                        with T.block("V_load"):
                                                            i = T.axis.spatial(32, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) // 128)
                                                            j = T.axis.spatial(128, (lz_ly_fused_0 * 512 + lz_ly_fused_1 * 128 + lz_ly_fused_2 * 4 + lz_ly_fused_3) % 128)
                                                            T.reads()
                                                            T.writes()
                                                            cur_L: T.int32 = L_kv_start + i
                                                            if cur_L < kv_chunk_len[0]:
                                                                seq_offset: T.int32 = cur_L
                                                                page_no: T.int32 = page_values[cur_page_indptr_begin + seq_offset // 16]
                                                                page_offset: T.int32 = seq_offset % 16
                                                                V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
                                                            else:
                                                                V_smem[i, j] = T.float16(0.0)
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(Q_smem[0:32, 0:128], K_smem[0:32, 0:128])
                                            T.writes(S_local[0:32, 0:32])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(2, 4):
                                                        with T.block("S_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 8 * 2 + li_1_init)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 8 * 4 + lj_1_init)
                                                            T.reads()
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = T.float32(0.0)
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, li_1, lj_1, lk_1 in T.grid(16, 2, 4, 8):
                                                        with T.block("S_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                            j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                            k = T.axis.reduce(128, lk_0 * 8 + lk_1)
                                                            T.reads(S_local[i, j], Q_smem[i, k], K_smem[j, k])
                                                            T.writes(S_local[i, j])
                                                            S_local[i, j] = S_local[i, j] + T.Cast("float32", Q_smem[i, k]) * T.Cast("float32", K_smem[j, k]) * attn_score_scaling_factor * T.float32(0.12751743082459868)
                                        T.tvm_storage_sync("shared")
                                        for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                for li_1, lj_1 in T.grid(2, 4):
                                                    with T.block("S_store"):
                                                        i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 8 * 2 + li_1)
                                                        j = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 8 * 4 + lj_1)
                                                        T.reads(S_local[i, j])
                                                        T.writes(S_smem[i, j])
                                                        S_smem[i, j] = S_local[i, j]
                                        T.tvm_storage_sync("shared")
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update1"):
                                                    T.reads(m_smem[row], kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], m_new[i], S_smem[row, 0:32], d_smem[row], m_prev[i])
                                                    T.writes(m_prev[i], m_new[i], d_new[i])
                                                    m_prev[i] = m_smem[row]
                                                    m_new[i] = m_smem[row]
                                                    row_: T.int32 = LH_start + row
                                                    for j in range(32):
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                                                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            with T.block("update"):
                                                T.reads(kv_chunk_len[0], tree_order_indptr[b_idx:b_idx + 2], tree_order[T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]):T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0]) + (T.max(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] + 31 - kv_chunk_len[0]) + 1 - T.min(LH_start + row + tree_order_indptr[b_idx + 1] + q_indptr[b_idx] - q_indptr[b_idx + 1], L_kv_start + tree_order_indptr[b_idx + 1] - kv_chunk_len[0])), 0:2], q_indptr[b_idx:b_idx + 2], S_smem[row, 0:32], m_new[i])
                                                T.writes(S_smem[row, 0:32])
                                                for j in range(32):
                                                    if row < 32:
                                                        row_: T.int32 = LH_start + row
                                                        if L_kv_start + j < kv_chunk_len[0] and (L_kv_start + j < kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) or tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] >= tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 0] and tree_order[tree_order_indptr[b_idx] + (row_ + (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]) - (q_indptr[b_idx + 1] - q_indptr[b_idx])), 0] < tree_order[tree_order_indptr[b_idx] + (L_kv_start + j - (kv_chunk_len[0] - (tree_order_indptr[b_idx + 1] - tree_order_indptr[b_idx]))), 1]):
                                                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                        else:
                                                            S_smem[row, j] = T.exp2(T.float32(-50000.0) - m_new[i])
                                        for i in range(1):
                                            row: T.int32 = i * 32 * 4 + ty * 32 + tx
                                            if row < 32:
                                                with T.block("update"):
                                                    T.reads(d_new[i], S_smem[row, 0:32], m_new[i], m_prev[i])
                                                    T.writes(d_new[i], m_smem[row], d_smem[row], m_prev_smem[row])
                                                    for j in range(32):
                                                        d_new[i] = d_new[i] + S_smem[row, j]
                                                    m_smem[row] = m_new[i]
                                                    d_smem[row] = d_new[i]
                                                    m_prev_smem[row] = m_prev[i]
                                        T.tvm_storage_sync("shared")
                                        with T.block(""):
                                            T.reads(m_prev_smem[0:32], m_smem[0:32], S_smem[0:32, 0:32], V_smem[0:32, 0:128])
                                            T.writes(O_local[0:32, 0:128])
                                            for li_0_lj_0_fused_0_init in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1_init in T.thread_binding(32, thread="threadIdx.x"):
                                                    for li_1_init, lj_1_init in T.grid(4, 8):
                                                        with T.block("O_gemm_init"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) // 16 * 4 + li_1_init)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0_init * 32 + li_0_lj_0_fused_1_init) % 16 * 8 + lj_1_init)
                                                            T.reads()
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] * T.exp2(m_prev_smem[i] - m_smem[i])
                                            for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                                for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                                    for lk_0, lk_1, li_1, lj_1 in T.grid(4, 8, 4, 8):
                                                        with T.block("O_gemm_update"):
                                                            i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                            j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                            k = T.axis.reduce(32, lk_0 * 8 + lk_1)
                                                            T.reads(O_local[i, j], m_prev_smem[i], m_smem[i], S_smem[i, k], V_smem[k, j])
                                                            T.writes(O_local[i, j])
                                                            O_local[i, j] = O_local[i, j] + S_smem[i, k] * T.Cast("float32", V_smem[k, j])
                                    for li_0_lj_0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
                                        for li_0_lj_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
                                            for li_1, lj_1 in T.grid(4, 8):
                                                with T.block("O_store"):
                                                    i = T.axis.spatial(32, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) // 16 * 4 + li_1)
                                                    j = T.axis.spatial(128, (li_0_lj_0_fused_0 * 32 + li_0_lj_0_fused_1) % 16 * 8 + lj_1)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], O_local[i, j], d_smem[i])
                                                    T.writes(output[q_indptr[b_idx] + (LH_start + i), by, j])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        output[cur_L, cur_H_qo, j] = T.Cast("float16", O_local[i, j] / d_smem[i])
                                    for li_0 in range(1):
                                        for li_1 in T.thread_binding(4, thread="threadIdx.y"):
                                            for li_2 in T.thread_binding(32, thread="threadIdx.x"):
                                                with T.block("lse_store"):
                                                    i = T.axis.spatial(32, li_0 * 128 + li_1 * 32 + li_2)
                                                    T.where((li_0 * 4 + li_1) * 32 + li_2 < 32)
                                                    T.reads(q_indptr[b_idx:b_idx + 2], m_smem[i], d_smem[i])
                                                    T.writes(lse[q_indptr[b_idx] + (LH_start + i), by])
                                                    cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i)
                                                    cur_H_qo: T.int32 = by
                                                    if cur_L < q_indptr[b_idx + 1]:
                                                        lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
                                    tile_id[0] = tile_id[0] + 16

    @R.function
    def alloc_embedding_tensor() -> R.Tensor((32768, 2048), dtype="float16"):
        # Allocate and return a (32768, 2048) float16 tensor. Takes no inputs and
        # writes no values into the tensor here; callers are expected to fill it.
        # NOTE(review): 32768 equals the "seq_len" upper bound declared in
        # batch_decode's attrs and 2048 equals the last dim of its input_embeds,
        # so this is presumably a max-size embedding staging buffer -- confirm.
        #
        # NOTE(review): this attr presumably tells the memory planner that the
        # function's output is allocated dynamically (not statically planned)
        # -- confirm against the Relax memory-planning pass.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True})
        # Arguments are: shape, dtype, R.prim_value(0), and the string "global".
        # NOTE(review): the third argument is presumably the runtime device index
        # and the fourth the storage scope -- confirm against R.builtin.alloc_tensor.
        gv: R.Tensor((32768, 2048), dtype="float16") = R.builtin.alloc_tensor(R.shape([32768, 2048]), R.dtype("float16"), R.prim_value(0), R.str("global"))
        return gv

    @R.function
    def argsort_probs(probs: R.Tensor(("batch_size", "vocab_size"), dtype="float32")) -> R.Tuple(R.Tensor(("batch_size", "vocab_size"), dtype="float32"), R.Tensor(("batch_size", "vocab_size"), dtype="int32")):
        # Sort a (batch_size, vocab_size) float32 tensor via two TIR kernels and
        # return both the reordered values and the int32 index tensor.
        #
        # Parameters:
        #   probs: float32 tensor of shape (batch_size, vocab_size).
        # Returns:
        #   A 2-tuple (lv2, lv1):
        #     lv2 -- float32 (batch_size, vocab_size), output of take_sorted_probs
        #     lv1 -- int32   (batch_size, vocab_size), output of argsort
        # NOTE(review): sort axis and direction (ascending/descending) are decided
        # inside cls.argsort, which is not visible here -- confirm before relying
        # on the ordering.

        # Symbolic int64 shape variables bound by the annotation on `probs`.
        batch_size = T.int64(is_size_var=True)
        vocab_size = T.int64(is_size_var=True)
        # "num_positions" and "num_samples" are listed in the upper-bound table but
        # are not used by this function; only batch_size and vocab_size appear in
        # its shapes.
        R.func_attr({"relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "num_positions": 480, "num_samples": 80}})
        cls = Module
        with R.dataflow():
            # Step 1: compute the int32 index tensor from probs.
            lv1 = R.call_tir(cls.argsort, (probs,), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="int32"))
            # Step 2: feed probs and the indices from step 1 to take_sorted_probs,
            # yielding a float32 tensor of the same shape.
            # NOTE(review): presumably a per-row gather of probs by lv1 -- verify
            # against the take_sorted_probs prim_func.
            lv2 = R.call_tir(cls.take_sorted_probs, (probs, lv1), out_sinfo=R.Tensor((batch_size, vocab_size), dtype="float32"))
            # Values come first in the tuple, indices second.
            gv1: R.Tuple(R.Tensor((batch_size, vocab_size), dtype="float32"), R.Tensor((batch_size, vocab_size), dtype="int32")) = lv2, lv1
            R.output(gv1)
        return gv1

    @R.function
    def batch_decode(input_embeds: R.Tensor(("batch_size", 1, 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), 
R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), 
dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), 
R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), 
dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), 
R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), 
dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), 
R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), 
dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"))) -> R.Tuple(R.Tensor(("batch_size", 1, 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 32768, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_layers_0_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[10]
            model_layers_0_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[11]
            model_layers_0_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[12]
            model_layers_0_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[13]
            model_layers_0_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[14]
            model_layers_0_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[15]
            model_layers_0_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[16]
            model_layers_0_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[17]
            model_layers_0_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[18]
            model_layers_1_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[19]
            model_layers_1_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[20]
            model_layers_1_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[21]
            model_layers_1_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[22]
            model_layers_1_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[23]
            model_layers_1_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[24]
            model_layers_1_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[25]
            model_layers_1_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[26]
            model_layers_1_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[27]
            model_layers_1_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[28]
            model_layers_1_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[29]
            model_layers_1_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[30]
            model_layers_1_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[31]
            model_layers_1_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[32]
            model_layers_1_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[33]
            model_layers_1_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_1_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[35]
            model_layers_2_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[36]
            model_layers_2_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[37]
            model_layers_2_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[38]
            model_layers_2_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[39]
            model_layers_2_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[40]
            model_layers_2_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[41]
            model_layers_2_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[42]
            model_layers_2_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[43]
            model_layers_2_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[44]
            model_layers_2_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[45]
            model_layers_2_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[46]
            model_layers_2_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[47]
            model_layers_2_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[48]
            model_layers_2_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[49]
            model_layers_2_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[50]
            model_layers_2_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[51]
            model_layers_2_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[52]
            model_layers_3_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[53]
            model_layers_3_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[54]
            model_layers_3_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[55]
            model_layers_3_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[56]
            model_layers_3_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[57]
            model_layers_3_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[58]
            model_layers_3_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[59]
            model_layers_3_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[60]
            model_layers_3_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[61]
            model_layers_3_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[62]
            model_layers_3_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[63]
            model_layers_3_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[64]
            model_layers_3_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[65]
            model_layers_3_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[66]
            model_layers_3_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[67]
            model_layers_3_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[68]
            model_layers_3_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[69]
            model_layers_4_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[70]
            model_layers_4_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[71]
            model_layers_4_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[72]
            model_layers_4_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[73]
            model_layers_4_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[74]
            model_layers_4_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[75]
            model_layers_4_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[76]
            model_layers_4_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[77]
            model_layers_4_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[78]
            model_layers_4_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[79]
            model_layers_4_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[80]
            model_layers_4_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[81]
            model_layers_4_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[82]
            model_layers_4_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[83]
            model_layers_4_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[84]
            model_layers_4_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[85]
            model_layers_4_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[86]
            model_layers_5_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[87]
            model_layers_5_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[88]
            model_layers_5_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[89]
            model_layers_5_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[90]
            model_layers_5_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[91]
            model_layers_5_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[92]
            model_layers_5_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[93]
            model_layers_5_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[94]
            model_layers_5_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[95]
            model_layers_5_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[96]
            model_layers_5_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[97]
            model_layers_5_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[98]
            model_layers_5_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[99]
            model_layers_5_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[100]
            model_layers_5_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[101]
            model_layers_5_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[102]
            model_layers_5_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[103]
            model_layers_6_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[104]
            model_layers_6_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[105]
            model_layers_6_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[106]
            model_layers_6_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[107]
            model_layers_6_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[108]
            model_layers_6_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[109]
            model_layers_6_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[110]
            model_layers_6_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[111]
            model_layers_6_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[112]
            model_layers_6_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[113]
            model_layers_6_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[114]
            model_layers_6_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[115]
            model_layers_6_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[116]
            model_layers_6_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[117]
            model_layers_6_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[118]
            model_layers_6_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[119]
            model_layers_6_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[120]
            model_layers_7_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[121]
            model_layers_7_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[122]
            model_layers_7_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[123]
            model_layers_7_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[124]
            model_layers_7_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[125]
            model_layers_7_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[126]
            model_layers_7_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[127]
            model_layers_7_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[128]
            model_layers_7_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[129]
            model_layers_7_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[130]
            model_layers_7_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[131]
            model_layers_7_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[132]
            model_layers_7_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[133]
            model_layers_7_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[134]
            model_layers_7_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[135]
            model_layers_7_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[136]
            model_layers_7_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[137]
            model_layers_8_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[138]
            model_layers_8_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[139]
            model_layers_8_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[140]
            model_layers_8_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[141]
            model_layers_8_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[142]
            model_layers_8_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[143]
            model_layers_8_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[144]
            model_layers_8_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[145]
            model_layers_8_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[146]
            model_layers_8_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[147]
            model_layers_8_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[148]
            model_layers_8_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[149]
            model_layers_8_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[150]
            model_layers_8_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[151]
            model_layers_8_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[152]
            model_layers_8_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[153]
            model_layers_8_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_9_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[155]
            model_layers_9_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[156]
            model_layers_9_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[157]
            model_layers_9_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[158]
            model_layers_9_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[159]
            model_layers_9_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[160]
            model_layers_9_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[161]
            model_layers_9_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[162]
            model_layers_9_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[163]
            model_layers_9_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[164]
            model_layers_9_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[165]
            model_layers_9_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[166]
            model_layers_9_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[167]
            model_layers_9_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[168]
            model_layers_9_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[169]
            model_layers_9_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[170]
            model_layers_9_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[171]
            model_layers_10_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[172]
            model_layers_10_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[173]
            model_layers_10_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[174]
            model_layers_10_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[175]
            model_layers_10_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[176]
            model_layers_10_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[177]
            model_layers_10_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[178]
            model_layers_10_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[179]
            model_layers_10_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[180]
            model_layers_10_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[181]
            model_layers_10_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[182]
            model_layers_10_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[183]
            model_layers_10_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[184]
            model_layers_10_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[185]
            model_layers_10_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[186]
            model_layers_10_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_10_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_11_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[189]
            model_layers_11_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[190]
            model_layers_11_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[191]
            model_layers_11_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_11_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_11_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[194]
            model_layers_11_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[195]
            model_layers_11_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[196]
            model_layers_11_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[197]
            model_layers_11_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[198]
            model_layers_11_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[199]
            model_layers_11_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[200]
            model_layers_11_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[201]
            model_layers_11_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[202]
            model_layers_11_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[203]
            model_layers_11_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[204]
            model_layers_11_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[205]
            model_layers_12_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[206]
            model_layers_12_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[207]
            model_layers_12_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[208]
            model_layers_12_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[209]
            model_layers_12_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[210]
            model_layers_12_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[211]
            model_layers_12_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[212]
            model_layers_12_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[213]
            model_layers_12_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[214]
            model_layers_12_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[215]
            model_layers_12_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[216]
            model_layers_12_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[217]
            model_layers_12_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[218]
            model_layers_12_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[219]
            model_layers_12_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[220]
            model_layers_12_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_12_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[222]
            model_layers_13_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[223]
            model_layers_13_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[224]
            model_layers_13_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[225]
            model_layers_13_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[226]
            model_layers_13_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[227]
            model_layers_13_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[228]
            model_layers_13_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[229]
            model_layers_13_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[230]
            model_layers_13_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[231]
            model_layers_13_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[232]
            model_layers_13_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[233]
            model_layers_13_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[234]
            model_layers_13_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[235]
            model_layers_13_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[236]
            model_layers_13_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[237]
            model_layers_13_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[238]
            model_layers_13_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[239]
            model_layers_14_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[240]
            model_layers_14_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[241]
            model_layers_14_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[242]
            model_layers_14_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[243]
            model_layers_14_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[244]
            model_layers_14_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[245]
            model_layers_14_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[246]
            model_layers_14_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[247]
            model_layers_14_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[248]
            model_layers_14_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[249]
            model_layers_14_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[250]
            model_layers_14_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[251]
            model_layers_14_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[252]
            model_layers_14_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[253]
            model_layers_14_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[254]
            model_layers_14_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[255]
            model_layers_14_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[256]
            model_layers_15_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[257]
            model_layers_15_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[258]
            model_layers_15_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[259]
            model_layers_15_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[260]
            model_layers_15_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[261]
            model_layers_15_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[262]
            model_layers_15_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[263]
            model_layers_15_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[264]
            model_layers_15_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[265]
            model_layers_15_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[266]
            model_layers_15_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[267]
            model_layers_15_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[268]
            model_layers_15_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[269]
            model_layers_15_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[270]
            model_layers_15_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[271]
            model_layers_15_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[272]
            model_layers_15_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[273]
            model_layers_16_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[274]
            model_layers_16_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[275]
            model_layers_16_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[276]
            model_layers_16_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[277]
            model_layers_16_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[278]
            model_layers_16_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[279]
            model_layers_16_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[280]
            model_layers_16_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[281]
            model_layers_16_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[282]
            model_layers_16_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[283]
            model_layers_16_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[284]
            model_layers_16_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[285]
            model_layers_16_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[286]
            model_layers_16_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[287]
            model_layers_16_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[288]
            model_layers_16_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[289]
            model_layers_16_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[290]
            model_layers_17_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[291]
            model_layers_17_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[292]
            model_layers_17_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[293]
            model_layers_17_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[294]
            model_layers_17_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[295]
            model_layers_17_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[296]
            model_layers_17_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[297]
            model_layers_17_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[298]
            model_layers_17_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[299]
            model_layers_17_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[300]
            model_layers_17_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[301]
            model_layers_17_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[302]
            model_layers_17_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[303]
            model_layers_17_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[304]
            model_layers_17_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[305]
            model_layers_17_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[306]
            model_layers_17_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[307]
            model_layers_18_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[308]
            model_layers_18_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[309]
            model_layers_18_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[310]
            model_layers_18_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[311]
            model_layers_18_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[312]
            model_layers_18_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[313]
            model_layers_18_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[314]
            model_layers_18_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[315]
            model_layers_18_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[316]
            model_layers_18_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[317]
            model_layers_18_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[318]
            model_layers_18_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[319]
            model_layers_18_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[320]
            model_layers_18_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[321]
            model_layers_18_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[322]
            model_layers_18_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[323]
            model_layers_18_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[324]
            model_layers_19_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[325]
            model_layers_19_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[326]
            model_layers_19_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[327]
            model_layers_19_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[328]
            model_layers_19_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[329]
            model_layers_19_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[330]
            model_layers_19_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[331]
            model_layers_19_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[332]
            model_layers_19_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[333]
            model_layers_19_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[334]
            model_layers_19_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[335]
            model_layers_19_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[336]
            model_layers_19_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[337]
            model_layers_19_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[338]
            model_layers_19_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[339]
            model_layers_19_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[340]
            model_layers_19_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_20_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[342]
            model_layers_20_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[343]
            model_layers_20_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[344]
            model_layers_20_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[345]
            model_layers_20_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[346]
            model_layers_20_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[347]
            model_layers_20_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[348]
            model_layers_20_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[349]
            model_layers_20_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[350]
            model_layers_20_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[351]
            model_layers_20_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[352]
            model_layers_20_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[353]
            model_layers_20_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[354]
            model_layers_20_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[355]
            model_layers_20_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[356]
            model_layers_20_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[357]
            model_layers_20_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[358]
            model_layers_21_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[359]
            model_layers_21_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[360]
            model_layers_21_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[361]
            model_layers_21_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[362]
            model_layers_21_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[363]
            model_layers_21_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[364]
            model_layers_21_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[365]
            model_layers_21_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[366]
            model_layers_21_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[367]
            model_layers_21_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[368]
            model_layers_21_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[369]
            model_layers_21_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[370]
            model_layers_21_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[371]
            model_layers_21_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[372]
            model_layers_21_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[373]
            model_layers_21_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_21_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_22_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[376]
            model_layers_22_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[377]
            model_layers_22_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[378]
            model_layers_22_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_22_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_22_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[381]
            model_layers_22_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[382]
            model_layers_22_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[383]
            model_layers_22_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[384]
            model_layers_22_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[385]
            model_layers_22_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[386]
            model_layers_22_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[387]
            model_layers_22_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[388]
            model_layers_22_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[389]
            model_layers_22_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[390]
            model_layers_22_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[391]
            model_layers_22_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[392]
            model_layers_23_self_attn_c_attn_q_weight4: R.Tensor((6144, 256), dtype="uint32") = packed_params[393]
            model_layers_23_self_attn_c_attn_q_scale4: R.Tensor((6144, 64), dtype="float16") = packed_params[394]
            model_layers_23_self_attn_c_attn_bias4: R.Tensor((6144,), dtype="float16") = packed_params[395]
            model_layers_23_self_attn_o_proj_q_weight4: R.Tensor((2048, 256), dtype="uint32") = packed_params[396]
            model_layers_23_self_attn_o_proj_q_scale4: R.Tensor((2048, 64), dtype="float16") = packed_params[397]
            model_layers_23_mlp_shared_expert_gate_up_proj_q_weight4: R.Tensor((11264, 256), dtype="uint32") = packed_params[398]
            model_layers_23_mlp_shared_expert_gate_up_proj_q_scale4: R.Tensor((11264, 64), dtype="float16") = packed_params[399]
            model_layers_23_mlp_shared_expert_down_proj_q_weight4: R.Tensor((2048, 704), dtype="uint32") = packed_params[400]
            model_layers_23_mlp_shared_expert_down_proj_q_scale4: R.Tensor((2048, 176), dtype="float16") = packed_params[401]
            model_layers_23_mlp_shared_expert_gate_weight4: R.Tensor((1, 2048), dtype="float16") = packed_params[402]
            model_layers_23_mlp_gate_weight4: R.Tensor((60, 2048), dtype="float16") = packed_params[403]
            model_layers_23_mlp_moe_gate_up_proj_q_weight4: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[404]
            model_layers_23_mlp_moe_gate_up_proj_q_scale4: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[405]
            model_layers_23_mlp_moe_down_proj_q_weight4: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[406]
            model_layers_23_mlp_moe_down_proj_q_scale4: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[407]
            model_layers_23_input_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[408]
            model_layers_23_post_attention_layernorm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[409]
            model_norm_weight4: R.Tensor((2048,), dtype="float16") = packed_params[410]
            lm_head_q_weight4: R.Tensor((151936, 256), dtype="uint32") = packed_params[411]
            lm_head_q_scale4: R.Tensor((151936, 64), dtype="float16") = packed_params[412]
            rms_norm147 = R.call_tir(cls.rms_norm, (input_embeds, model_layers_0_input_layernorm_weight4), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_0_self_attn_c_attn_q_weight4, model_layers_0_self_attn_c_attn_q_scale4, rms_norm147, model_layers_0_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape624 = R.call_tir(cls.reshape, (lv,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape625 = R.call_tir(cls.reshape1, (reshape624,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv774 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape625), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape626 = R.call_tir(cls.reshape2, (lv774,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape627 = R.call_tir(cls.reshape3, (reshape626,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_0_self_attn_o_proj_q_weight4, model_layers_0_self_attn_o_proj_q_scale4, reshape627), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv_2 = R.call_tir(cls.fuse_add_norm_decode, (lv_1, input_embeds, model_layers_0_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_2[1]
            rms_norm148: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv_2[0]
            reshape628 = R.call_tir(cls.reshape4, (rms_norm148,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv1_1 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape628, model_layers_0_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv2 = R.call_tir(cls.fused_softmax_cast1, (lv1_1,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv776 = R.call_tir(cls.top4_softmax, (lv2,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_072: R.Tensor((batch_size, 4), dtype="float16") = lv776[0]
            top4_softmax_172: R.Tensor((batch_size, 4), dtype="int32") = lv776[1]
            lv3 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_172,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape629 = R.call_tir(cls.reshape5, (lv3,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape629, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv1_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum48: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv1_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv778 = R.call_tir(cls.get_indices, (cumsum48, top4_softmax_172), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_048: R.Tensor((batch_size * 4,), dtype="int32") = lv778[0]
            get_indices_148: R.Tensor((batch_size * 4,), dtype="int32") = lv778[1]
            lv779 = R.call_tir(cls.get_expert_instance_indptr, (cumsum48,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take50 = R.call_tir(cls.take, (reshape628, get_indices_148), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv780 = R.call_tir(cls.dequantize_group_gemm, (take50, model_layers_0_mlp_moe_gate_up_proj_q_weight4, model_layers_0_mlp_moe_gate_up_proj_q_scale4, lv779), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv4 = R.call_tir(cls.fused_split_silu_multiply, (lv780,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv781 = R.call_tir(cls.dequantize_group_gemm1, (lv4, model_layers_0_mlp_moe_down_proj_q_weight4, model_layers_0_mlp_moe_down_proj_q_scale4, lv779), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv782 = R.call_tir(cls.scatter_output, (lv781, get_indices_048), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape630 = R.call_tir(cls.reshape6, (top4_softmax_072,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape631 = R.call_tir(cls.reshape7, (lv782,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv5 = R.call_tir(cls.fused_multiply1_sum, (reshape631, reshape630), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv1_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_0_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_0_mlp_shared_expert_gate_up_proj_q_scale4, reshape628), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv6 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv1_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv7 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape628, model_layers_0_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv_4 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_0_mlp_shared_expert_down_proj_q_weight4, model_layers_0_mlp_shared_expert_down_proj_q_scale4, lv6, lv7, lv5), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape632 = R.call_tir(cls.reshape8, (lv_4,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv2_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape632, lv1, model_layers_1_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv3_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[1]
            rms_norm149: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv2_1[0]
            lv1_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_1_self_attn_c_attn_q_weight4, model_layers_1_self_attn_c_attn_q_scale4, rms_norm149, model_layers_1_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape633 = R.call_tir(cls.reshape, (lv1_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape634 = R.call_tir(cls.reshape1, (reshape633,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv786 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape634), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape635 = R.call_tir(cls.reshape2, (lv786,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape636 = R.call_tir(cls.reshape3, (reshape635,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv2_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_1_self_attn_o_proj_q_weight4, model_layers_1_self_attn_o_proj_q_scale4, reshape636), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv4_1 = R.call_tir(cls.fuse_add_norm_decode, (lv2_2, lv3_1, model_layers_1_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv5_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4_1[1]
            rms_norm150: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv4_1[0]
            reshape637 = R.call_tir(cls.reshape4, (rms_norm150,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv10 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape637, model_layers_1_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv11 = R.call_tir(cls.fused_softmax_cast1, (lv10,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv788 = R.call_tir(cls.top4_softmax, (lv11,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_073: R.Tensor((batch_size, 4), dtype="float16") = lv788[0]
            top4_softmax_173: R.Tensor((batch_size, 4), dtype="int32") = lv788[1]
            lv12 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_173,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape638 = R.call_tir(cls.reshape5, (lv12,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv2_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape638, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv3_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv2_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum49: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv3_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv790 = R.call_tir(cls.get_indices, (cumsum49, top4_softmax_173), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_049: R.Tensor((batch_size * 4,), dtype="int32") = lv790[0]
            get_indices_149: R.Tensor((batch_size * 4,), dtype="int32") = lv790[1]
            lv791 = R.call_tir(cls.get_expert_instance_indptr, (cumsum49,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take51 = R.call_tir(cls.take, (reshape637, get_indices_149), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv792 = R.call_tir(cls.dequantize_group_gemm, (take51, model_layers_1_mlp_moe_gate_up_proj_q_weight4, model_layers_1_mlp_moe_gate_up_proj_q_scale4, lv791), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv13 = R.call_tir(cls.fused_split_silu_multiply, (lv792,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv793 = R.call_tir(cls.dequantize_group_gemm1, (lv13, model_layers_1_mlp_moe_down_proj_q_weight4, model_layers_1_mlp_moe_down_proj_q_scale4, lv791), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv794 = R.call_tir(cls.scatter_output, (lv793, get_indices_049), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape639 = R.call_tir(cls.reshape6, (top4_softmax_073,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape640 = R.call_tir(cls.reshape7, (lv794,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv14 = R.call_tir(cls.fused_multiply1_sum, (reshape640, reshape639), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv3_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_1_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_1_mlp_shared_expert_gate_up_proj_q_scale4, reshape637), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv15 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv3_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv16 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape637, model_layers_1_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv1_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_1_mlp_shared_expert_down_proj_q_weight4, model_layers_1_mlp_shared_expert_down_proj_q_scale4, lv15, lv16, lv14), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape641 = R.call_tir(cls.reshape8, (lv1_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv6_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape641, lv5_1, model_layers_2_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv7_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6_1[1]
            rms_norm151: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv6_1[0]
            lv2_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_2_self_attn_c_attn_q_weight4, model_layers_2_self_attn_c_attn_q_scale4, rms_norm151, model_layers_2_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape642 = R.call_tir(cls.reshape, (lv2_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape643 = R.call_tir(cls.reshape1, (reshape642,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv798 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape643), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape644 = R.call_tir(cls.reshape2, (lv798,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape645 = R.call_tir(cls.reshape3, (reshape644,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv4_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_2_self_attn_o_proj_q_weight4, model_layers_2_self_attn_o_proj_q_scale4, reshape645), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv8 = R.call_tir(cls.fuse_add_norm_decode, (lv4_2, lv7_1, model_layers_2_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv9: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8[1]
            rms_norm152: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv8[0]
            reshape646 = R.call_tir(cls.reshape4, (rms_norm152,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv19 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape646, model_layers_2_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv20 = R.call_tir(cls.fused_softmax_cast1, (lv19,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv800 = R.call_tir(cls.top4_softmax, (lv20,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_074: R.Tensor((batch_size, 4), dtype="float16") = lv800[0]
            top4_softmax_174: R.Tensor((batch_size, 4), dtype="int32") = lv800[1]
            lv21 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_174,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape647 = R.call_tir(cls.reshape5, (lv21,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv4_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape647, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv5_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv4_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum50: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv5_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv802 = R.call_tir(cls.get_indices, (cumsum50, top4_softmax_174), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_050: R.Tensor((batch_size * 4,), dtype="int32") = lv802[0]
            get_indices_150: R.Tensor((batch_size * 4,), dtype="int32") = lv802[1]
            lv803 = R.call_tir(cls.get_expert_instance_indptr, (cumsum50,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take52 = R.call_tir(cls.take, (reshape646, get_indices_150), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv804 = R.call_tir(cls.dequantize_group_gemm, (take52, model_layers_2_mlp_moe_gate_up_proj_q_weight4, model_layers_2_mlp_moe_gate_up_proj_q_scale4, lv803), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv22 = R.call_tir(cls.fused_split_silu_multiply, (lv804,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv805 = R.call_tir(cls.dequantize_group_gemm1, (lv22, model_layers_2_mlp_moe_down_proj_q_weight4, model_layers_2_mlp_moe_down_proj_q_scale4, lv803), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv806 = R.call_tir(cls.scatter_output, (lv805, get_indices_050), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape648 = R.call_tir(cls.reshape6, (top4_softmax_074,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape649 = R.call_tir(cls.reshape7, (lv806,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv23 = R.call_tir(cls.fused_multiply1_sum, (reshape649, reshape648), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv5_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_2_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_2_mlp_shared_expert_gate_up_proj_q_scale4, reshape646), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv24 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv5_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv25 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape646, model_layers_2_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv2_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_2_mlp_shared_expert_down_proj_q_weight4, model_layers_2_mlp_shared_expert_down_proj_q_scale4, lv24, lv25, lv23), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape650 = R.call_tir(cls.reshape8, (lv2_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv10_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape650, lv9, model_layers_3_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv11_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10_1[1]
            rms_norm153: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv10_1[0]
            lv3_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_3_self_attn_c_attn_q_weight4, model_layers_3_self_attn_c_attn_q_scale4, rms_norm153, model_layers_3_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape651 = R.call_tir(cls.reshape, (lv3_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape652 = R.call_tir(cls.reshape1, (reshape651,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv810 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape652), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape653 = R.call_tir(cls.reshape2, (lv810,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape654 = R.call_tir(cls.reshape3, (reshape653,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv6_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_3_self_attn_o_proj_q_weight4, model_layers_3_self_attn_o_proj_q_scale4, reshape654), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv12_1 = R.call_tir(cls.fuse_add_norm_decode, (lv6_2, lv11_1, model_layers_3_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv13_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12_1[1]
            rms_norm154: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv12_1[0]
            reshape655 = R.call_tir(cls.reshape4, (rms_norm154,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv28 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape655, model_layers_3_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv29 = R.call_tir(cls.fused_softmax_cast1, (lv28,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv812 = R.call_tir(cls.top4_softmax, (lv29,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_075: R.Tensor((batch_size, 4), dtype="float16") = lv812[0]
            top4_softmax_175: R.Tensor((batch_size, 4), dtype="int32") = lv812[1]
            lv30 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_175,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape656 = R.call_tir(cls.reshape5, (lv30,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv6_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape656, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv7_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv6_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum51: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv7_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv814 = R.call_tir(cls.get_indices, (cumsum51, top4_softmax_175), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_051: R.Tensor((batch_size * 4,), dtype="int32") = lv814[0]
            get_indices_151: R.Tensor((batch_size * 4,), dtype="int32") = lv814[1]
            lv815 = R.call_tir(cls.get_expert_instance_indptr, (cumsum51,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take53 = R.call_tir(cls.take, (reshape655, get_indices_151), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv816 = R.call_tir(cls.dequantize_group_gemm, (take53, model_layers_3_mlp_moe_gate_up_proj_q_weight4, model_layers_3_mlp_moe_gate_up_proj_q_scale4, lv815), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv31 = R.call_tir(cls.fused_split_silu_multiply, (lv816,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv817 = R.call_tir(cls.dequantize_group_gemm1, (lv31, model_layers_3_mlp_moe_down_proj_q_weight4, model_layers_3_mlp_moe_down_proj_q_scale4, lv815), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv818 = R.call_tir(cls.scatter_output, (lv817, get_indices_051), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape657 = R.call_tir(cls.reshape6, (top4_softmax_075,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape658 = R.call_tir(cls.reshape7, (lv818,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv32 = R.call_tir(cls.fused_multiply1_sum, (reshape658, reshape657), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv7_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_3_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_3_mlp_shared_expert_gate_up_proj_q_scale4, reshape655), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv33 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv7_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv34 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape655, model_layers_3_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv3_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_3_mlp_shared_expert_down_proj_q_weight4, model_layers_3_mlp_shared_expert_down_proj_q_scale4, lv33, lv34, lv32), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape659 = R.call_tir(cls.reshape8, (lv3_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv14_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape659, lv13_1, model_layers_4_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv15_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14_1[1]
            rms_norm155: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv14_1[0]
            lv4_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_4_self_attn_c_attn_q_weight4, model_layers_4_self_attn_c_attn_q_scale4, rms_norm155, model_layers_4_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape660 = R.call_tir(cls.reshape, (lv4_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape661 = R.call_tir(cls.reshape1, (reshape660,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv822 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape661), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape662 = R.call_tir(cls.reshape2, (lv822,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape663 = R.call_tir(cls.reshape3, (reshape662,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv8_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_4_self_attn_o_proj_q_weight4, model_layers_4_self_attn_o_proj_q_scale4, reshape663), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv16_1 = R.call_tir(cls.fuse_add_norm_decode, (lv8_1, lv15_1, model_layers_4_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv17: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16_1[1]
            rms_norm156: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv16_1[0]
            reshape664 = R.call_tir(cls.reshape4, (rms_norm156,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv37 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape664, model_layers_4_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv38 = R.call_tir(cls.fused_softmax_cast1, (lv37,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv824 = R.call_tir(cls.top4_softmax, (lv38,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_076: R.Tensor((batch_size, 4), dtype="float16") = lv824[0]
            top4_softmax_176: R.Tensor((batch_size, 4), dtype="int32") = lv824[1]
            lv39 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_176,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape665 = R.call_tir(cls.reshape5, (lv39,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv8_2: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape665, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv9_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv8_2,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum52: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv9_1, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv826 = R.call_tir(cls.get_indices, (cumsum52, top4_softmax_176), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_052: R.Tensor((batch_size * 4,), dtype="int32") = lv826[0]
            get_indices_152: R.Tensor((batch_size * 4,), dtype="int32") = lv826[1]
            lv827 = R.call_tir(cls.get_expert_instance_indptr, (cumsum52,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take54 = R.call_tir(cls.take, (reshape664, get_indices_152), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv828 = R.call_tir(cls.dequantize_group_gemm, (take54, model_layers_4_mlp_moe_gate_up_proj_q_weight4, model_layers_4_mlp_moe_gate_up_proj_q_scale4, lv827), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv40 = R.call_tir(cls.fused_split_silu_multiply, (lv828,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv829 = R.call_tir(cls.dequantize_group_gemm1, (lv40, model_layers_4_mlp_moe_down_proj_q_weight4, model_layers_4_mlp_moe_down_proj_q_scale4, lv827), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv830 = R.call_tir(cls.scatter_output, (lv829, get_indices_052), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape666 = R.call_tir(cls.reshape6, (top4_softmax_076,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape667 = R.call_tir(cls.reshape7, (lv830,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv41 = R.call_tir(cls.fused_multiply1_sum, (reshape667, reshape666), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv9_2 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_4_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_4_mlp_shared_expert_gate_up_proj_q_scale4, reshape664), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv42 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv9_2,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv43 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape664, model_layers_4_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv4_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_4_mlp_shared_expert_down_proj_q_weight4, model_layers_4_mlp_shared_expert_down_proj_q_scale4, lv42, lv43, lv41), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape668 = R.call_tir(cls.reshape8, (lv4_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv18 = R.call_tir(cls.fuse_add_norm_decode, (reshape668, lv17, model_layers_5_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv19_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18[1]
            rms_norm157: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv18[0]
            lv5_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_5_self_attn_c_attn_q_weight4, model_layers_5_self_attn_c_attn_q_scale4, rms_norm157, model_layers_5_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape669 = R.call_tir(cls.reshape, (lv5_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape670 = R.call_tir(cls.reshape1, (reshape669,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv834 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape670), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape671 = R.call_tir(cls.reshape2, (lv834,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape672 = R.call_tir(cls.reshape3, (reshape671,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv10_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_5_self_attn_o_proj_q_weight4, model_layers_5_self_attn_o_proj_q_scale4, reshape672), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv20_1 = R.call_tir(cls.fuse_add_norm_decode, (lv10_2, lv19_1, model_layers_5_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv21_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20_1[1]
            rms_norm158: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv20_1[0]
            reshape673 = R.call_tir(cls.reshape4, (rms_norm158,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv46 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape673, model_layers_5_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv47 = R.call_tir(cls.fused_softmax_cast1, (lv46,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv836 = R.call_tir(cls.top4_softmax, (lv47,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_077: R.Tensor((batch_size, 4), dtype="float16") = lv836[0]
            top4_softmax_177: R.Tensor((batch_size, 4), dtype="int32") = lv836[1]
            lv48 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_177,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape674 = R.call_tir(cls.reshape5, (lv48,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv10_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape674, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv11_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv10_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum53: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv11_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv838 = R.call_tir(cls.get_indices, (cumsum53, top4_softmax_177), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_053: R.Tensor((batch_size * 4,), dtype="int32") = lv838[0]
            get_indices_153: R.Tensor((batch_size * 4,), dtype="int32") = lv838[1]
            lv839 = R.call_tir(cls.get_expert_instance_indptr, (cumsum53,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take55 = R.call_tir(cls.take, (reshape673, get_indices_153), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv840 = R.call_tir(cls.dequantize_group_gemm, (take55, model_layers_5_mlp_moe_gate_up_proj_q_weight4, model_layers_5_mlp_moe_gate_up_proj_q_scale4, lv839), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv49 = R.call_tir(cls.fused_split_silu_multiply, (lv840,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv841 = R.call_tir(cls.dequantize_group_gemm1, (lv49, model_layers_5_mlp_moe_down_proj_q_weight4, model_layers_5_mlp_moe_down_proj_q_scale4, lv839), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv842 = R.call_tir(cls.scatter_output, (lv841, get_indices_053), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape675 = R.call_tir(cls.reshape6, (top4_softmax_077,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape676 = R.call_tir(cls.reshape7, (lv842,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv50 = R.call_tir(cls.fused_multiply1_sum, (reshape676, reshape675), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv11_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_5_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_5_mlp_shared_expert_gate_up_proj_q_scale4, reshape673), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv51 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv11_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv52 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape673, model_layers_5_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv5_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_5_mlp_shared_expert_down_proj_q_weight4, model_layers_5_mlp_shared_expert_down_proj_q_scale4, lv51, lv52, lv50), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape677 = R.call_tir(cls.reshape8, (lv5_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv22_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape677, lv21_1, model_layers_6_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv23_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22_1[1]
            rms_norm159: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv22_1[0]
            lv6_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_6_self_attn_c_attn_q_weight4, model_layers_6_self_attn_c_attn_q_scale4, rms_norm159, model_layers_6_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape678 = R.call_tir(cls.reshape, (lv6_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape679 = R.call_tir(cls.reshape1, (reshape678,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv846 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape679), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape680 = R.call_tir(cls.reshape2, (lv846,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape681 = R.call_tir(cls.reshape3, (reshape680,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv12_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_6_self_attn_o_proj_q_weight4, model_layers_6_self_attn_o_proj_q_scale4, reshape681), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv24_1 = R.call_tir(cls.fuse_add_norm_decode, (lv12_2, lv23_1, model_layers_6_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv25_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24_1[1]
            rms_norm160: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv24_1[0]
            reshape682 = R.call_tir(cls.reshape4, (rms_norm160,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv55 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape682, model_layers_6_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv56 = R.call_tir(cls.fused_softmax_cast1, (lv55,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv848 = R.call_tir(cls.top4_softmax, (lv56,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_078: R.Tensor((batch_size, 4), dtype="float16") = lv848[0]
            top4_softmax_178: R.Tensor((batch_size, 4), dtype="int32") = lv848[1]
            lv57 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_178,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape683 = R.call_tir(cls.reshape5, (lv57,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv12_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape683, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv13_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv12_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum54: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv13_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv850 = R.call_tir(cls.get_indices, (cumsum54, top4_softmax_178), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_054: R.Tensor((batch_size * 4,), dtype="int32") = lv850[0]
            get_indices_154: R.Tensor((batch_size * 4,), dtype="int32") = lv850[1]
            lv851 = R.call_tir(cls.get_expert_instance_indptr, (cumsum54,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take56 = R.call_tir(cls.take, (reshape682, get_indices_154), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv852 = R.call_tir(cls.dequantize_group_gemm, (take56, model_layers_6_mlp_moe_gate_up_proj_q_weight4, model_layers_6_mlp_moe_gate_up_proj_q_scale4, lv851), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv58 = R.call_tir(cls.fused_split_silu_multiply, (lv852,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv853 = R.call_tir(cls.dequantize_group_gemm1, (lv58, model_layers_6_mlp_moe_down_proj_q_weight4, model_layers_6_mlp_moe_down_proj_q_scale4, lv851), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv854 = R.call_tir(cls.scatter_output, (lv853, get_indices_054), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape684 = R.call_tir(cls.reshape6, (top4_softmax_078,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape685 = R.call_tir(cls.reshape7, (lv854,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv59 = R.call_tir(cls.fused_multiply1_sum, (reshape685, reshape684), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv13_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_6_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_6_mlp_shared_expert_gate_up_proj_q_scale4, reshape682), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv60 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv13_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv61 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape682, model_layers_6_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv6_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_6_mlp_shared_expert_down_proj_q_weight4, model_layers_6_mlp_shared_expert_down_proj_q_scale4, lv60, lv61, lv59), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape686 = R.call_tir(cls.reshape8, (lv6_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv26 = R.call_tir(cls.fuse_add_norm_decode, (reshape686, lv25_1, model_layers_7_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv27: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26[1]
            rms_norm161: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv26[0]
            lv7_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_7_self_attn_c_attn_q_weight4, model_layers_7_self_attn_c_attn_q_scale4, rms_norm161, model_layers_7_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape687 = R.call_tir(cls.reshape, (lv7_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape688 = R.call_tir(cls.reshape1, (reshape687,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv858 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape688), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape689 = R.call_tir(cls.reshape2, (lv858,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape690 = R.call_tir(cls.reshape3, (reshape689,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv14_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_7_self_attn_o_proj_q_weight4, model_layers_7_self_attn_o_proj_q_scale4, reshape690), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv28_1 = R.call_tir(cls.fuse_add_norm_decode, (lv14_2, lv27, model_layers_7_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv29_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28_1[1]
            rms_norm162: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv28_1[0]
            reshape691 = R.call_tir(cls.reshape4, (rms_norm162,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv64 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape691, model_layers_7_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv65 = R.call_tir(cls.fused_softmax_cast1, (lv64,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv860 = R.call_tir(cls.top4_softmax, (lv65,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_079: R.Tensor((batch_size, 4), dtype="float16") = lv860[0]
            top4_softmax_179: R.Tensor((batch_size, 4), dtype="int32") = lv860[1]
            lv66 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_179,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape692 = R.call_tir(cls.reshape5, (lv66,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv14_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape692, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv15_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv14_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum55: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv15_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv862 = R.call_tir(cls.get_indices, (cumsum55, top4_softmax_179), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_055: R.Tensor((batch_size * 4,), dtype="int32") = lv862[0]
            get_indices_155: R.Tensor((batch_size * 4,), dtype="int32") = lv862[1]
            lv863 = R.call_tir(cls.get_expert_instance_indptr, (cumsum55,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take57 = R.call_tir(cls.take, (reshape691, get_indices_155), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv864 = R.call_tir(cls.dequantize_group_gemm, (take57, model_layers_7_mlp_moe_gate_up_proj_q_weight4, model_layers_7_mlp_moe_gate_up_proj_q_scale4, lv863), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv67 = R.call_tir(cls.fused_split_silu_multiply, (lv864,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv865 = R.call_tir(cls.dequantize_group_gemm1, (lv67, model_layers_7_mlp_moe_down_proj_q_weight4, model_layers_7_mlp_moe_down_proj_q_scale4, lv863), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv866 = R.call_tir(cls.scatter_output, (lv865, get_indices_055), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape693 = R.call_tir(cls.reshape6, (top4_softmax_079,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape694 = R.call_tir(cls.reshape7, (lv866,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv68 = R.call_tir(cls.fused_multiply1_sum, (reshape694, reshape693), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv15_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_7_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_7_mlp_shared_expert_gate_up_proj_q_scale4, reshape691), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv69 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv15_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv70 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape691, model_layers_7_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv7_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_7_mlp_shared_expert_down_proj_q_weight4, model_layers_7_mlp_shared_expert_down_proj_q_scale4, lv69, lv70, lv68), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape695 = R.call_tir(cls.reshape8, (lv7_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv30_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape695, lv29_1, model_layers_8_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv31_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30_1[1]
            rms_norm163: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv30_1[0]
            lv8_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_8_self_attn_c_attn_q_weight4, model_layers_8_self_attn_c_attn_q_scale4, rms_norm163, model_layers_8_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape696 = R.call_tir(cls.reshape, (lv8_3,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape697 = R.call_tir(cls.reshape1, (reshape696,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv870 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape697), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape698 = R.call_tir(cls.reshape2, (lv870,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape699 = R.call_tir(cls.reshape3, (reshape698,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv16_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_8_self_attn_o_proj_q_weight4, model_layers_8_self_attn_o_proj_q_scale4, reshape699), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv32_1 = R.call_tir(cls.fuse_add_norm_decode, (lv16_2, lv31_1, model_layers_8_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv33_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32_1[1]
            rms_norm164: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv32_1[0]
            reshape700 = R.call_tir(cls.reshape4, (rms_norm164,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv73 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape700, model_layers_8_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv74 = R.call_tir(cls.fused_softmax_cast1, (lv73,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv872 = R.call_tir(cls.top4_softmax, (lv74,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_080: R.Tensor((batch_size, 4), dtype="float16") = lv872[0]
            top4_softmax_180: R.Tensor((batch_size, 4), dtype="int32") = lv872[1]
            lv75 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_180,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape701 = R.call_tir(cls.reshape5, (lv75,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv16_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape701, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv17_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv16_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum56: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv17_1, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv874 = R.call_tir(cls.get_indices, (cumsum56, top4_softmax_180), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_056: R.Tensor((batch_size * 4,), dtype="int32") = lv874[0]
            get_indices_156: R.Tensor((batch_size * 4,), dtype="int32") = lv874[1]
            lv875 = R.call_tir(cls.get_expert_instance_indptr, (cumsum56,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take58 = R.call_tir(cls.take, (reshape700, get_indices_156), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv876 = R.call_tir(cls.dequantize_group_gemm, (take58, model_layers_8_mlp_moe_gate_up_proj_q_weight4, model_layers_8_mlp_moe_gate_up_proj_q_scale4, lv875), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv76 = R.call_tir(cls.fused_split_silu_multiply, (lv876,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv877 = R.call_tir(cls.dequantize_group_gemm1, (lv76, model_layers_8_mlp_moe_down_proj_q_weight4, model_layers_8_mlp_moe_down_proj_q_scale4, lv875), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv878 = R.call_tir(cls.scatter_output, (lv877, get_indices_056), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape702 = R.call_tir(cls.reshape6, (top4_softmax_080,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape703 = R.call_tir(cls.reshape7, (lv878,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv77 = R.call_tir(cls.fused_multiply1_sum, (reshape703, reshape702), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv17_2 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_8_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_8_mlp_shared_expert_gate_up_proj_q_scale4, reshape700), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv78 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv17_2,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv79 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape700, model_layers_8_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv8_4 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_8_mlp_shared_expert_down_proj_q_weight4, model_layers_8_mlp_shared_expert_down_proj_q_scale4, lv78, lv79, lv77), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape704 = R.call_tir(cls.reshape8, (lv8_4,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv34_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape704, lv33_1, model_layers_9_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv35: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34_1[1]
            rms_norm165: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv34_1[0]
            lv9_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_9_self_attn_c_attn_q_weight4, model_layers_9_self_attn_c_attn_q_scale4, rms_norm165, model_layers_9_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape705 = R.call_tir(cls.reshape, (lv9_3,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape706 = R.call_tir(cls.reshape1, (reshape705,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv882 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape706), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape707 = R.call_tir(cls.reshape2, (lv882,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape708 = R.call_tir(cls.reshape3, (reshape707,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv18_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_9_self_attn_o_proj_q_weight4, model_layers_9_self_attn_o_proj_q_scale4, reshape708), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv36 = R.call_tir(cls.fuse_add_norm_decode, (lv18_1, lv35, model_layers_9_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv37_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36[1]
            rms_norm166: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv36[0]
            reshape709 = R.call_tir(cls.reshape4, (rms_norm166,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv82 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape709, model_layers_9_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv83 = R.call_tir(cls.fused_softmax_cast1, (lv82,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv884 = R.call_tir(cls.top4_softmax, (lv83,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_081: R.Tensor((batch_size, 4), dtype="float16") = lv884[0]
            top4_softmax_181: R.Tensor((batch_size, 4), dtype="int32") = lv884[1]
            lv84 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_181,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape710 = R.call_tir(cls.reshape5, (lv84,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv18_2: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape710, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv19_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv18_2,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum57: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv19_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv886 = R.call_tir(cls.get_indices, (cumsum57, top4_softmax_181), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_057: R.Tensor((batch_size * 4,), dtype="int32") = lv886[0]
            get_indices_157: R.Tensor((batch_size * 4,), dtype="int32") = lv886[1]
            lv887 = R.call_tir(cls.get_expert_instance_indptr, (cumsum57,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take59 = R.call_tir(cls.take, (reshape709, get_indices_157), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv888 = R.call_tir(cls.dequantize_group_gemm, (take59, model_layers_9_mlp_moe_gate_up_proj_q_weight4, model_layers_9_mlp_moe_gate_up_proj_q_scale4, lv887), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv85 = R.call_tir(cls.fused_split_silu_multiply, (lv888,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv889 = R.call_tir(cls.dequantize_group_gemm1, (lv85, model_layers_9_mlp_moe_down_proj_q_weight4, model_layers_9_mlp_moe_down_proj_q_scale4, lv887), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv890 = R.call_tir(cls.scatter_output, (lv889, get_indices_057), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape711 = R.call_tir(cls.reshape6, (top4_softmax_081,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape712 = R.call_tir(cls.reshape7, (lv890,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv86 = R.call_tir(cls.fused_multiply1_sum, (reshape712, reshape711), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv19_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_9_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_9_mlp_shared_expert_gate_up_proj_q_scale4, reshape709), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv87 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv19_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv88 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape709, model_layers_9_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv9_4 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_9_mlp_shared_expert_down_proj_q_weight4, model_layers_9_mlp_shared_expert_down_proj_q_scale4, lv87, lv88, lv86), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape713 = R.call_tir(cls.reshape8, (lv9_4,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv38_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape713, lv37_1, model_layers_10_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv39_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38_1[1]
            rms_norm167: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv38_1[0]
            lv10_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_10_self_attn_c_attn_q_weight4, model_layers_10_self_attn_c_attn_q_scale4, rms_norm167, model_layers_10_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape714 = R.call_tir(cls.reshape, (lv10_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape715 = R.call_tir(cls.reshape1, (reshape714,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv894 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape715), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape716 = R.call_tir(cls.reshape2, (lv894,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape717 = R.call_tir(cls.reshape3, (reshape716,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv20_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_10_self_attn_o_proj_q_weight4, model_layers_10_self_attn_o_proj_q_scale4, reshape717), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv40_1 = R.call_tir(cls.fuse_add_norm_decode, (lv20_2, lv39_1, model_layers_10_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv41_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40_1[1]
            rms_norm168: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv40_1[0]
            reshape718 = R.call_tir(cls.reshape4, (rms_norm168,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv91 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape718, model_layers_10_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv92 = R.call_tir(cls.fused_softmax_cast1, (lv91,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv896 = R.call_tir(cls.top4_softmax, (lv92,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_082: R.Tensor((batch_size, 4), dtype="float16") = lv896[0]
            top4_softmax_182: R.Tensor((batch_size, 4), dtype="int32") = lv896[1]
            lv93 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_182,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape719 = R.call_tir(cls.reshape5, (lv93,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv20_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape719, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv21_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv20_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum58: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv21_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv898 = R.call_tir(cls.get_indices, (cumsum58, top4_softmax_182), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_058: R.Tensor((batch_size * 4,), dtype="int32") = lv898[0]
            get_indices_158: R.Tensor((batch_size * 4,), dtype="int32") = lv898[1]
            lv899 = R.call_tir(cls.get_expert_instance_indptr, (cumsum58,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take60 = R.call_tir(cls.take, (reshape718, get_indices_158), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv900 = R.call_tir(cls.dequantize_group_gemm, (take60, model_layers_10_mlp_moe_gate_up_proj_q_weight4, model_layers_10_mlp_moe_gate_up_proj_q_scale4, lv899), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv94 = R.call_tir(cls.fused_split_silu_multiply, (lv900,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv901 = R.call_tir(cls.dequantize_group_gemm1, (lv94, model_layers_10_mlp_moe_down_proj_q_weight4, model_layers_10_mlp_moe_down_proj_q_scale4, lv899), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv902 = R.call_tir(cls.scatter_output, (lv901, get_indices_058), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape720 = R.call_tir(cls.reshape6, (top4_softmax_082,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape721 = R.call_tir(cls.reshape7, (lv902,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv95 = R.call_tir(cls.fused_multiply1_sum, (reshape721, reshape720), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv21_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_10_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_10_mlp_shared_expert_gate_up_proj_q_scale4, reshape718), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv96 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv21_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv97 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape718, model_layers_10_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv10_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_10_mlp_shared_expert_down_proj_q_weight4, model_layers_10_mlp_shared_expert_down_proj_q_scale4, lv96, lv97, lv95), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape722 = R.call_tir(cls.reshape8, (lv10_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv42_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape722, lv41_1, model_layers_11_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv43_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42_1[1]
            rms_norm169: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv42_1[0]
            lv11_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_11_self_attn_c_attn_q_weight4, model_layers_11_self_attn_c_attn_q_scale4, rms_norm169, model_layers_11_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape723 = R.call_tir(cls.reshape, (lv11_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape724 = R.call_tir(cls.reshape1, (reshape723,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv906 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape724), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape725 = R.call_tir(cls.reshape2, (lv906,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape726 = R.call_tir(cls.reshape3, (reshape725,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv22_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_11_self_attn_o_proj_q_weight4, model_layers_11_self_attn_o_proj_q_scale4, reshape726), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv44 = R.call_tir(cls.fuse_add_norm_decode, (lv22_2, lv43_1, model_layers_11_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv45: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44[1]
            rms_norm170: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv44[0]
            reshape727 = R.call_tir(cls.reshape4, (rms_norm170,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv100 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape727, model_layers_11_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv101 = R.call_tir(cls.fused_softmax_cast1, (lv100,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv908 = R.call_tir(cls.top4_softmax, (lv101,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_083: R.Tensor((batch_size, 4), dtype="float16") = lv908[0]
            top4_softmax_183: R.Tensor((batch_size, 4), dtype="int32") = lv908[1]
            lv102 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_183,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape728 = R.call_tir(cls.reshape5, (lv102,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv22_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape728, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv23_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv22_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum59: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv23_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv910 = R.call_tir(cls.get_indices, (cumsum59, top4_softmax_183), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_059: R.Tensor((batch_size * 4,), dtype="int32") = lv910[0]
            get_indices_159: R.Tensor((batch_size * 4,), dtype="int32") = lv910[1]
            lv911 = R.call_tir(cls.get_expert_instance_indptr, (cumsum59,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take61 = R.call_tir(cls.take, (reshape727, get_indices_159), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv912 = R.call_tir(cls.dequantize_group_gemm, (take61, model_layers_11_mlp_moe_gate_up_proj_q_weight4, model_layers_11_mlp_moe_gate_up_proj_q_scale4, lv911), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv103 = R.call_tir(cls.fused_split_silu_multiply, (lv912,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv913 = R.call_tir(cls.dequantize_group_gemm1, (lv103, model_layers_11_mlp_moe_down_proj_q_weight4, model_layers_11_mlp_moe_down_proj_q_scale4, lv911), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv914 = R.call_tir(cls.scatter_output, (lv913, get_indices_059), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape729 = R.call_tir(cls.reshape6, (top4_softmax_083,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape730 = R.call_tir(cls.reshape7, (lv914,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv104 = R.call_tir(cls.fused_multiply1_sum, (reshape730, reshape729), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv23_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_11_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_11_mlp_shared_expert_gate_up_proj_q_scale4, reshape727), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv105 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv23_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv106 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape727, model_layers_11_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv11_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_11_mlp_shared_expert_down_proj_q_weight4, model_layers_11_mlp_shared_expert_down_proj_q_scale4, lv105, lv106, lv104), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape731 = R.call_tir(cls.reshape8, (lv11_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv46_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape731, lv45, model_layers_12_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv47_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46_1[1]
            rms_norm171: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv46_1[0]
            lv12_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_12_self_attn_c_attn_q_weight4, model_layers_12_self_attn_c_attn_q_scale4, rms_norm171, model_layers_12_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape732 = R.call_tir(cls.reshape, (lv12_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape733 = R.call_tir(cls.reshape1, (reshape732,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv918 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape733), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape734 = R.call_tir(cls.reshape2, (lv918,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape735 = R.call_tir(cls.reshape3, (reshape734,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv24_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_12_self_attn_o_proj_q_weight4, model_layers_12_self_attn_o_proj_q_scale4, reshape735), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv48_1 = R.call_tir(cls.fuse_add_norm_decode, (lv24_2, lv47_1, model_layers_12_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv49_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48_1[1]
            rms_norm172: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv48_1[0]
            reshape736 = R.call_tir(cls.reshape4, (rms_norm172,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv109 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape736, model_layers_12_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv110 = R.call_tir(cls.fused_softmax_cast1, (lv109,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv920 = R.call_tir(cls.top4_softmax, (lv110,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_084: R.Tensor((batch_size, 4), dtype="float16") = lv920[0]
            top4_softmax_184: R.Tensor((batch_size, 4), dtype="int32") = lv920[1]
            lv111 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_184,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape737 = R.call_tir(cls.reshape5, (lv111,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv24_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape737, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv25_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv24_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum60: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv25_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv922 = R.call_tir(cls.get_indices, (cumsum60, top4_softmax_184), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_060: R.Tensor((batch_size * 4,), dtype="int32") = lv922[0]
            get_indices_160: R.Tensor((batch_size * 4,), dtype="int32") = lv922[1]
            lv923 = R.call_tir(cls.get_expert_instance_indptr, (cumsum60,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take62 = R.call_tir(cls.take, (reshape736, get_indices_160), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv924 = R.call_tir(cls.dequantize_group_gemm, (take62, model_layers_12_mlp_moe_gate_up_proj_q_weight4, model_layers_12_mlp_moe_gate_up_proj_q_scale4, lv923), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv112 = R.call_tir(cls.fused_split_silu_multiply, (lv924,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv925 = R.call_tir(cls.dequantize_group_gemm1, (lv112, model_layers_12_mlp_moe_down_proj_q_weight4, model_layers_12_mlp_moe_down_proj_q_scale4, lv923), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv926 = R.call_tir(cls.scatter_output, (lv925, get_indices_060), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape738 = R.call_tir(cls.reshape6, (top4_softmax_084,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape739 = R.call_tir(cls.reshape7, (lv926,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv113 = R.call_tir(cls.fused_multiply1_sum, (reshape739, reshape738), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv25_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_12_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_12_mlp_shared_expert_gate_up_proj_q_scale4, reshape736), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv114 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv25_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv115 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape736, model_layers_12_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv12_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_12_mlp_shared_expert_down_proj_q_weight4, model_layers_12_mlp_shared_expert_down_proj_q_scale4, lv114, lv115, lv113), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape740 = R.call_tir(cls.reshape8, (lv12_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv50_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape740, lv49_1, model_layers_13_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv51_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50_1[1]
            rms_norm173: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv50_1[0]
            lv13_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_13_self_attn_c_attn_q_weight4, model_layers_13_self_attn_c_attn_q_scale4, rms_norm173, model_layers_13_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape741 = R.call_tir(cls.reshape, (lv13_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape742 = R.call_tir(cls.reshape1, (reshape741,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv930 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape742), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape743 = R.call_tir(cls.reshape2, (lv930,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape744 = R.call_tir(cls.reshape3, (reshape743,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv26_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_13_self_attn_o_proj_q_weight4, model_layers_13_self_attn_o_proj_q_scale4, reshape744), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv52_1 = R.call_tir(cls.fuse_add_norm_decode, (lv26_1, lv51_1, model_layers_13_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv53: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52_1[1]
            rms_norm174: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv52_1[0]
            reshape745 = R.call_tir(cls.reshape4, (rms_norm174,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv118 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape745, model_layers_13_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv119 = R.call_tir(cls.fused_softmax_cast1, (lv118,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv932 = R.call_tir(cls.top4_softmax, (lv119,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_085: R.Tensor((batch_size, 4), dtype="float16") = lv932[0]
            top4_softmax_185: R.Tensor((batch_size, 4), dtype="int32") = lv932[1]
            lv120 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_185,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape746 = R.call_tir(cls.reshape5, (lv120,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv26_2: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape746, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv27_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv26_2,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum61: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv27_1, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv934 = R.call_tir(cls.get_indices, (cumsum61, top4_softmax_185), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_061: R.Tensor((batch_size * 4,), dtype="int32") = lv934[0]
            get_indices_161: R.Tensor((batch_size * 4,), dtype="int32") = lv934[1]
            lv935 = R.call_tir(cls.get_expert_instance_indptr, (cumsum61,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take63 = R.call_tir(cls.take, (reshape745, get_indices_161), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv936 = R.call_tir(cls.dequantize_group_gemm, (take63, model_layers_13_mlp_moe_gate_up_proj_q_weight4, model_layers_13_mlp_moe_gate_up_proj_q_scale4, lv935), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv121 = R.call_tir(cls.fused_split_silu_multiply, (lv936,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv937 = R.call_tir(cls.dequantize_group_gemm1, (lv121, model_layers_13_mlp_moe_down_proj_q_weight4, model_layers_13_mlp_moe_down_proj_q_scale4, lv935), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv938 = R.call_tir(cls.scatter_output, (lv937, get_indices_061), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape747 = R.call_tir(cls.reshape6, (top4_softmax_085,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape748 = R.call_tir(cls.reshape7, (lv938,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv122 = R.call_tir(cls.fused_multiply1_sum, (reshape748, reshape747), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv27_2 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_13_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_13_mlp_shared_expert_gate_up_proj_q_scale4, reshape745), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv123 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv27_2,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv124 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape745, model_layers_13_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv13_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_13_mlp_shared_expert_down_proj_q_weight4, model_layers_13_mlp_shared_expert_down_proj_q_scale4, lv123, lv124, lv122), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape749 = R.call_tir(cls.reshape8, (lv13_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv54 = R.call_tir(cls.fuse_add_norm_decode, (reshape749, lv53, model_layers_14_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv55_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54[1]
            rms_norm175: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv54[0]
            lv14_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_14_self_attn_c_attn_q_weight4, model_layers_14_self_attn_c_attn_q_scale4, rms_norm175, model_layers_14_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape750 = R.call_tir(cls.reshape, (lv14_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape751 = R.call_tir(cls.reshape1, (reshape750,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv942 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape751), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape752 = R.call_tir(cls.reshape2, (lv942,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape753 = R.call_tir(cls.reshape3, (reshape752,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv28_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_14_self_attn_o_proj_q_weight4, model_layers_14_self_attn_o_proj_q_scale4, reshape753), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv56_1 = R.call_tir(cls.fuse_add_norm_decode, (lv28_2, lv55_1, model_layers_14_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv57_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56_1[1]
            rms_norm176: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv56_1[0]
            reshape754 = R.call_tir(cls.reshape4, (rms_norm176,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv127 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape754, model_layers_14_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv128 = R.call_tir(cls.fused_softmax_cast1, (lv127,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv944 = R.call_tir(cls.top4_softmax, (lv128,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_086: R.Tensor((batch_size, 4), dtype="float16") = lv944[0]
            top4_softmax_186: R.Tensor((batch_size, 4), dtype="int32") = lv944[1]
            lv129 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_186,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape755 = R.call_tir(cls.reshape5, (lv129,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv28_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape755, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv29_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv28_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum62: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv29_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv946 = R.call_tir(cls.get_indices, (cumsum62, top4_softmax_186), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_062: R.Tensor((batch_size * 4,), dtype="int32") = lv946[0]
            get_indices_162: R.Tensor((batch_size * 4,), dtype="int32") = lv946[1]
            lv947 = R.call_tir(cls.get_expert_instance_indptr, (cumsum62,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take64 = R.call_tir(cls.take, (reshape754, get_indices_162), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv948 = R.call_tir(cls.dequantize_group_gemm, (take64, model_layers_14_mlp_moe_gate_up_proj_q_weight4, model_layers_14_mlp_moe_gate_up_proj_q_scale4, lv947), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv130 = R.call_tir(cls.fused_split_silu_multiply, (lv948,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv949 = R.call_tir(cls.dequantize_group_gemm1, (lv130, model_layers_14_mlp_moe_down_proj_q_weight4, model_layers_14_mlp_moe_down_proj_q_scale4, lv947), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv950 = R.call_tir(cls.scatter_output, (lv949, get_indices_062), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape756 = R.call_tir(cls.reshape6, (top4_softmax_086,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape757 = R.call_tir(cls.reshape7, (lv950,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv131 = R.call_tir(cls.fused_multiply1_sum, (reshape757, reshape756), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv29_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_14_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_14_mlp_shared_expert_gate_up_proj_q_scale4, reshape754), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv132 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv29_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv133 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape754, model_layers_14_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv14_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_14_mlp_shared_expert_down_proj_q_weight4, model_layers_14_mlp_shared_expert_down_proj_q_scale4, lv132, lv133, lv131), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape758 = R.call_tir(cls.reshape8, (lv14_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv58_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape758, lv57_1, model_layers_15_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv59_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58_1[1]
            rms_norm177: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv58_1[0]
            lv15_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_15_self_attn_c_attn_q_weight4, model_layers_15_self_attn_c_attn_q_scale4, rms_norm177, model_layers_15_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape759 = R.call_tir(cls.reshape, (lv15_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape760 = R.call_tir(cls.reshape1, (reshape759,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv954 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape760), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape761 = R.call_tir(cls.reshape2, (lv954,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape762 = R.call_tir(cls.reshape3, (reshape761,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv30_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_15_self_attn_o_proj_q_weight4, model_layers_15_self_attn_o_proj_q_scale4, reshape762), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv60_1 = R.call_tir(cls.fuse_add_norm_decode, (lv30_2, lv59_1, model_layers_15_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv61_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60_1[1]
            rms_norm178: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv60_1[0]
            reshape763 = R.call_tir(cls.reshape4, (rms_norm178,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv136 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape763, model_layers_15_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv137 = R.call_tir(cls.fused_softmax_cast1, (lv136,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv956 = R.call_tir(cls.top4_softmax, (lv137,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_087: R.Tensor((batch_size, 4), dtype="float16") = lv956[0]
            top4_softmax_187: R.Tensor((batch_size, 4), dtype="int32") = lv956[1]
            lv138 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_187,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape764 = R.call_tir(cls.reshape5, (lv138,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv30_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape764, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv31_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv30_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum63: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv31_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv958 = R.call_tir(cls.get_indices, (cumsum63, top4_softmax_187), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_063: R.Tensor((batch_size * 4,), dtype="int32") = lv958[0]
            get_indices_163: R.Tensor((batch_size * 4,), dtype="int32") = lv958[1]
            lv959 = R.call_tir(cls.get_expert_instance_indptr, (cumsum63,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take65 = R.call_tir(cls.take, (reshape763, get_indices_163), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv960 = R.call_tir(cls.dequantize_group_gemm, (take65, model_layers_15_mlp_moe_gate_up_proj_q_weight4, model_layers_15_mlp_moe_gate_up_proj_q_scale4, lv959), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv139 = R.call_tir(cls.fused_split_silu_multiply, (lv960,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv961 = R.call_tir(cls.dequantize_group_gemm1, (lv139, model_layers_15_mlp_moe_down_proj_q_weight4, model_layers_15_mlp_moe_down_proj_q_scale4, lv959), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv962 = R.call_tir(cls.scatter_output, (lv961, get_indices_063), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape765 = R.call_tir(cls.reshape6, (top4_softmax_087,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape766 = R.call_tir(cls.reshape7, (lv962,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv140 = R.call_tir(cls.fused_multiply1_sum, (reshape766, reshape765), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv31_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_15_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_15_mlp_shared_expert_gate_up_proj_q_scale4, reshape763), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv141 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv31_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv142 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape763, model_layers_15_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv15_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_15_mlp_shared_expert_down_proj_q_weight4, model_layers_15_mlp_shared_expert_down_proj_q_scale4, lv141, lv142, lv140), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape767 = R.call_tir(cls.reshape8, (lv15_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv62 = R.call_tir(cls.fuse_add_norm_decode, (reshape767, lv61_1, model_layers_16_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv63: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62[1]
            rms_norm179: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv62[0]
            lv16_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_16_self_attn_c_attn_q_weight4, model_layers_16_self_attn_c_attn_q_scale4, rms_norm179, model_layers_16_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape768 = R.call_tir(cls.reshape, (lv16_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape769 = R.call_tir(cls.reshape1, (reshape768,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv966 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape769), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape770 = R.call_tir(cls.reshape2, (lv966,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape771 = R.call_tir(cls.reshape3, (reshape770,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv32_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_16_self_attn_o_proj_q_weight4, model_layers_16_self_attn_o_proj_q_scale4, reshape771), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv64_1 = R.call_tir(cls.fuse_add_norm_decode, (lv32_2, lv63, model_layers_16_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv65_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64_1[1]
            rms_norm180: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv64_1[0]
            reshape772 = R.call_tir(cls.reshape4, (rms_norm180,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv145 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape772, model_layers_16_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv146 = R.call_tir(cls.fused_softmax_cast1, (lv145,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv968 = R.call_tir(cls.top4_softmax, (lv146,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_088: R.Tensor((batch_size, 4), dtype="float16") = lv968[0]
            top4_softmax_188: R.Tensor((batch_size, 4), dtype="int32") = lv968[1]
            lv147 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_188,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape773 = R.call_tir(cls.reshape5, (lv147,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv32_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape773, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv33_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv32_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum64: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv33_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv970 = R.call_tir(cls.get_indices, (cumsum64, top4_softmax_188), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_064: R.Tensor((batch_size * 4,), dtype="int32") = lv970[0]
            get_indices_164: R.Tensor((batch_size * 4,), dtype="int32") = lv970[1]
            lv971 = R.call_tir(cls.get_expert_instance_indptr, (cumsum64,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take66 = R.call_tir(cls.take, (reshape772, get_indices_164), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv972 = R.call_tir(cls.dequantize_group_gemm, (take66, model_layers_16_mlp_moe_gate_up_proj_q_weight4, model_layers_16_mlp_moe_gate_up_proj_q_scale4, lv971), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv148 = R.call_tir(cls.fused_split_silu_multiply, (lv972,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv973 = R.call_tir(cls.dequantize_group_gemm1, (lv148, model_layers_16_mlp_moe_down_proj_q_weight4, model_layers_16_mlp_moe_down_proj_q_scale4, lv971), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv974 = R.call_tir(cls.scatter_output, (lv973, get_indices_064), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape774 = R.call_tir(cls.reshape6, (top4_softmax_088,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape775 = R.call_tir(cls.reshape7, (lv974,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv149 = R.call_tir(cls.fused_multiply1_sum, (reshape775, reshape774), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv33_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_16_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_16_mlp_shared_expert_gate_up_proj_q_scale4, reshape772), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv150 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv33_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv151 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape772, model_layers_16_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv16_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_16_mlp_shared_expert_down_proj_q_weight4, model_layers_16_mlp_shared_expert_down_proj_q_scale4, lv150, lv151, lv149), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape776 = R.call_tir(cls.reshape8, (lv16_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv66_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape776, lv65_1, model_layers_17_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv67_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66_1[1]
            rms_norm181: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv66_1[0]
            lv17_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_17_self_attn_c_attn_q_weight4, model_layers_17_self_attn_c_attn_q_scale4, rms_norm181, model_layers_17_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape777 = R.call_tir(cls.reshape, (lv17_3,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape778 = R.call_tir(cls.reshape1, (reshape777,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv978 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape778), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape779 = R.call_tir(cls.reshape2, (lv978,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape780 = R.call_tir(cls.reshape3, (reshape779,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv34_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_17_self_attn_o_proj_q_weight4, model_layers_17_self_attn_o_proj_q_scale4, reshape780), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv68_1 = R.call_tir(cls.fuse_add_norm_decode, (lv34_2, lv67_1, model_layers_17_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv69_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68_1[1]
            rms_norm182: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv68_1[0]
            reshape781 = R.call_tir(cls.reshape4, (rms_norm182,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv154 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape781, model_layers_17_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv155 = R.call_tir(cls.fused_softmax_cast1, (lv154,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv980 = R.call_tir(cls.top4_softmax, (lv155,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_089: R.Tensor((batch_size, 4), dtype="float16") = lv980[0]
            top4_softmax_189: R.Tensor((batch_size, 4), dtype="int32") = lv980[1]
            lv156 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_189,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape782 = R.call_tir(cls.reshape5, (lv156,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv34_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape782, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv35_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv34_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum65: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv35_1, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv982 = R.call_tir(cls.get_indices, (cumsum65, top4_softmax_189), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_065: R.Tensor((batch_size * 4,), dtype="int32") = lv982[0]
            get_indices_165: R.Tensor((batch_size * 4,), dtype="int32") = lv982[1]
            lv983 = R.call_tir(cls.get_expert_instance_indptr, (cumsum65,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take67 = R.call_tir(cls.take, (reshape781, get_indices_165), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv984 = R.call_tir(cls.dequantize_group_gemm, (take67, model_layers_17_mlp_moe_gate_up_proj_q_weight4, model_layers_17_mlp_moe_gate_up_proj_q_scale4, lv983), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv157 = R.call_tir(cls.fused_split_silu_multiply, (lv984,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv985 = R.call_tir(cls.dequantize_group_gemm1, (lv157, model_layers_17_mlp_moe_down_proj_q_weight4, model_layers_17_mlp_moe_down_proj_q_scale4, lv983), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv986 = R.call_tir(cls.scatter_output, (lv985, get_indices_065), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape783 = R.call_tir(cls.reshape6, (top4_softmax_089,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape784 = R.call_tir(cls.reshape7, (lv986,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv158 = R.call_tir(cls.fused_multiply1_sum, (reshape784, reshape783), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv35_2 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_17_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_17_mlp_shared_expert_gate_up_proj_q_scale4, reshape781), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv159 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv35_2,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv160 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape781, model_layers_17_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv17_4 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_17_mlp_shared_expert_down_proj_q_weight4, model_layers_17_mlp_shared_expert_down_proj_q_scale4, lv159, lv160, lv158), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape785 = R.call_tir(cls.reshape8, (lv17_4,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv70_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape785, lv69_1, model_layers_18_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv71: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70_1[1]
            rms_norm183: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv70_1[0]
            lv18_3 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_18_self_attn_c_attn_q_weight4, model_layers_18_self_attn_c_attn_q_scale4, rms_norm183, model_layers_18_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape786 = R.call_tir(cls.reshape, (lv18_3,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape787 = R.call_tir(cls.reshape1, (reshape786,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv990 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape787), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape788 = R.call_tir(cls.reshape2, (lv990,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape789 = R.call_tir(cls.reshape3, (reshape788,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv36_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_18_self_attn_o_proj_q_weight4, model_layers_18_self_attn_o_proj_q_scale4, reshape789), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv72 = R.call_tir(cls.fuse_add_norm_decode, (lv36_1, lv71, model_layers_18_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv73_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72[1]
            rms_norm184: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv72[0]
            reshape790 = R.call_tir(cls.reshape4, (rms_norm184,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv163 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape790, model_layers_18_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv164 = R.call_tir(cls.fused_softmax_cast1, (lv163,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv992 = R.call_tir(cls.top4_softmax, (lv164,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_090: R.Tensor((batch_size, 4), dtype="float16") = lv992[0]
            top4_softmax_190: R.Tensor((batch_size, 4), dtype="int32") = lv992[1]
            lv165 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_190,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape791 = R.call_tir(cls.reshape5, (lv165,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv36_2: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape791, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv37_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv36_2,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum66: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv37_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv994 = R.call_tir(cls.get_indices, (cumsum66, top4_softmax_190), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_066: R.Tensor((batch_size * 4,), dtype="int32") = lv994[0]
            get_indices_166: R.Tensor((batch_size * 4,), dtype="int32") = lv994[1]
            lv995 = R.call_tir(cls.get_expert_instance_indptr, (cumsum66,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take68 = R.call_tir(cls.take, (reshape790, get_indices_166), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv996 = R.call_tir(cls.dequantize_group_gemm, (take68, model_layers_18_mlp_moe_gate_up_proj_q_weight4, model_layers_18_mlp_moe_gate_up_proj_q_scale4, lv995), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv166 = R.call_tir(cls.fused_split_silu_multiply, (lv996,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv997 = R.call_tir(cls.dequantize_group_gemm1, (lv166, model_layers_18_mlp_moe_down_proj_q_weight4, model_layers_18_mlp_moe_down_proj_q_scale4, lv995), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv998 = R.call_tir(cls.scatter_output, (lv997, get_indices_066), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape792 = R.call_tir(cls.reshape6, (top4_softmax_090,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape793 = R.call_tir(cls.reshape7, (lv998,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv167 = R.call_tir(cls.fused_multiply1_sum, (reshape793, reshape792), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv37_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_18_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_18_mlp_shared_expert_gate_up_proj_q_scale4, reshape790), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv168 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv37_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv169 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape790, model_layers_18_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv18_4 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_18_mlp_shared_expert_down_proj_q_weight4, model_layers_18_mlp_shared_expert_down_proj_q_scale4, lv168, lv169, lv167), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape794 = R.call_tir(cls.reshape8, (lv18_4,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv74_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape794, lv73_1, model_layers_19_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv75_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74_1[1]
            rms_norm185: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv74_1[0]
            lv19_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_19_self_attn_c_attn_q_weight4, model_layers_19_self_attn_c_attn_q_scale4, rms_norm185, model_layers_19_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape795 = R.call_tir(cls.reshape, (lv19_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape796 = R.call_tir(cls.reshape1, (reshape795,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv1002 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape796), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape797 = R.call_tir(cls.reshape2, (lv1002,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape798 = R.call_tir(cls.reshape3, (reshape797,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv38_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_19_self_attn_o_proj_q_weight4, model_layers_19_self_attn_o_proj_q_scale4, reshape798), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv76_1 = R.call_tir(cls.fuse_add_norm_decode, (lv38_2, lv75_1, model_layers_19_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv77_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76_1[1]
            rms_norm186: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv76_1[0]
            reshape799 = R.call_tir(cls.reshape4, (rms_norm186,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv172 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape799, model_layers_19_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv173 = R.call_tir(cls.fused_softmax_cast1, (lv172,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv1004 = R.call_tir(cls.top4_softmax, (lv173,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_091: R.Tensor((batch_size, 4), dtype="float16") = lv1004[0]
            top4_softmax_191: R.Tensor((batch_size, 4), dtype="int32") = lv1004[1]
            lv174 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_191,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape800 = R.call_tir(cls.reshape5, (lv174,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv38_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape800, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv39_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv38_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum67: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv39_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv1006 = R.call_tir(cls.get_indices, (cumsum67, top4_softmax_191), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_067: R.Tensor((batch_size * 4,), dtype="int32") = lv1006[0]
            get_indices_167: R.Tensor((batch_size * 4,), dtype="int32") = lv1006[1]
            lv1007 = R.call_tir(cls.get_expert_instance_indptr, (cumsum67,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take69 = R.call_tir(cls.take, (reshape799, get_indices_167), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1008 = R.call_tir(cls.dequantize_group_gemm, (take69, model_layers_19_mlp_moe_gate_up_proj_q_weight4, model_layers_19_mlp_moe_gate_up_proj_q_scale4, lv1007), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv175 = R.call_tir(cls.fused_split_silu_multiply, (lv1008,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv1009 = R.call_tir(cls.dequantize_group_gemm1, (lv175, model_layers_19_mlp_moe_down_proj_q_weight4, model_layers_19_mlp_moe_down_proj_q_scale4, lv1007), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1010 = R.call_tir(cls.scatter_output, (lv1009, get_indices_067), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape801 = R.call_tir(cls.reshape6, (top4_softmax_091,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape802 = R.call_tir(cls.reshape7, (lv1010,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv176 = R.call_tir(cls.fused_multiply1_sum, (reshape802, reshape801), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv39_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_19_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_19_mlp_shared_expert_gate_up_proj_q_scale4, reshape799), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv177 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv39_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv178 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape799, model_layers_19_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv19_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_19_mlp_shared_expert_down_proj_q_weight4, model_layers_19_mlp_shared_expert_down_proj_q_scale4, lv177, lv178, lv176), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape803 = R.call_tir(cls.reshape8, (lv19_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv78_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape803, lv77_1, model_layers_20_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv79_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78_1[1]
            rms_norm187: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv78_1[0]
            lv20_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_20_self_attn_c_attn_q_weight4, model_layers_20_self_attn_c_attn_q_scale4, rms_norm187, model_layers_20_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape804 = R.call_tir(cls.reshape, (lv20_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape805 = R.call_tir(cls.reshape1, (reshape804,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv1014 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape805), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape806 = R.call_tir(cls.reshape2, (lv1014,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape807 = R.call_tir(cls.reshape3, (reshape806,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv40_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_20_self_attn_o_proj_q_weight4, model_layers_20_self_attn_o_proj_q_scale4, reshape807), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv80 = R.call_tir(cls.fuse_add_norm_decode, (lv40_2, lv79_1, model_layers_20_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv81: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80[1]
            rms_norm188: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv80[0]
            reshape808 = R.call_tir(cls.reshape4, (rms_norm188,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv181 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape808, model_layers_20_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv182 = R.call_tir(cls.fused_softmax_cast1, (lv181,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv1016 = R.call_tir(cls.top4_softmax, (lv182,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_092: R.Tensor((batch_size, 4), dtype="float16") = lv1016[0]
            top4_softmax_192: R.Tensor((batch_size, 4), dtype="int32") = lv1016[1]
            lv183 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_192,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape809 = R.call_tir(cls.reshape5, (lv183,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv40_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape809, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv41_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv40_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum68: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv41_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv1018 = R.call_tir(cls.get_indices, (cumsum68, top4_softmax_192), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_068: R.Tensor((batch_size * 4,), dtype="int32") = lv1018[0]
            get_indices_168: R.Tensor((batch_size * 4,), dtype="int32") = lv1018[1]
            lv1019 = R.call_tir(cls.get_expert_instance_indptr, (cumsum68,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take70 = R.call_tir(cls.take, (reshape808, get_indices_168), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1020 = R.call_tir(cls.dequantize_group_gemm, (take70, model_layers_20_mlp_moe_gate_up_proj_q_weight4, model_layers_20_mlp_moe_gate_up_proj_q_scale4, lv1019), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv184 = R.call_tir(cls.fused_split_silu_multiply, (lv1020,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv1021 = R.call_tir(cls.dequantize_group_gemm1, (lv184, model_layers_20_mlp_moe_down_proj_q_weight4, model_layers_20_mlp_moe_down_proj_q_scale4, lv1019), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1022 = R.call_tir(cls.scatter_output, (lv1021, get_indices_068), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape810 = R.call_tir(cls.reshape6, (top4_softmax_092,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape811 = R.call_tir(cls.reshape7, (lv1022,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv185 = R.call_tir(cls.fused_multiply1_sum, (reshape811, reshape810), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv41_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_20_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_20_mlp_shared_expert_gate_up_proj_q_scale4, reshape808), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv186 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv41_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv187 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape808, model_layers_20_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv20_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_20_mlp_shared_expert_down_proj_q_weight4, model_layers_20_mlp_shared_expert_down_proj_q_scale4, lv186, lv187, lv185), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape812 = R.call_tir(cls.reshape8, (lv20_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv82_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape812, lv81, model_layers_21_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv83_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82_1[1]
            rms_norm189: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv82_1[0]
            lv21_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_21_self_attn_c_attn_q_weight4, model_layers_21_self_attn_c_attn_q_scale4, rms_norm189, model_layers_21_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape813 = R.call_tir(cls.reshape, (lv21_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape814 = R.call_tir(cls.reshape1, (reshape813,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv1026 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape814), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape815 = R.call_tir(cls.reshape2, (lv1026,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape816 = R.call_tir(cls.reshape3, (reshape815,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv42_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_21_self_attn_o_proj_q_weight4, model_layers_21_self_attn_o_proj_q_scale4, reshape816), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv84_1 = R.call_tir(cls.fuse_add_norm_decode, (lv42_2, lv83_1, model_layers_21_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv85_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84_1[1]
            rms_norm190: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv84_1[0]
            reshape817 = R.call_tir(cls.reshape4, (rms_norm190,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv190 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape817, model_layers_21_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv191 = R.call_tir(cls.fused_softmax_cast1, (lv190,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv1028 = R.call_tir(cls.top4_softmax, (lv191,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_093: R.Tensor((batch_size, 4), dtype="float16") = lv1028[0]
            top4_softmax_193: R.Tensor((batch_size, 4), dtype="int32") = lv1028[1]
            lv192 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_193,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape818 = R.call_tir(cls.reshape5, (lv192,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv42_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape818, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv43_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv42_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum69: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv43_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv1030 = R.call_tir(cls.get_indices, (cumsum69, top4_softmax_193), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_069: R.Tensor((batch_size * 4,), dtype="int32") = lv1030[0]
            get_indices_169: R.Tensor((batch_size * 4,), dtype="int32") = lv1030[1]
            lv1031 = R.call_tir(cls.get_expert_instance_indptr, (cumsum69,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take71 = R.call_tir(cls.take, (reshape817, get_indices_169), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1032 = R.call_tir(cls.dequantize_group_gemm, (take71, model_layers_21_mlp_moe_gate_up_proj_q_weight4, model_layers_21_mlp_moe_gate_up_proj_q_scale4, lv1031), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv193 = R.call_tir(cls.fused_split_silu_multiply, (lv1032,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv1033 = R.call_tir(cls.dequantize_group_gemm1, (lv193, model_layers_21_mlp_moe_down_proj_q_weight4, model_layers_21_mlp_moe_down_proj_q_scale4, lv1031), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1034 = R.call_tir(cls.scatter_output, (lv1033, get_indices_069), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape819 = R.call_tir(cls.reshape6, (top4_softmax_093,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape820 = R.call_tir(cls.reshape7, (lv1034,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv194 = R.call_tir(cls.fused_multiply1_sum, (reshape820, reshape819), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv43_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_21_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_21_mlp_shared_expert_gate_up_proj_q_scale4, reshape817), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv195 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv43_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv196 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape817, model_layers_21_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv21_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_21_mlp_shared_expert_down_proj_q_weight4, model_layers_21_mlp_shared_expert_down_proj_q_scale4, lv195, lv196, lv194), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape821 = R.call_tir(cls.reshape8, (lv21_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv86_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape821, lv85_1, model_layers_22_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv87_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86_1[1]
            rms_norm191: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv86_1[0]
            lv22_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_22_self_attn_c_attn_q_weight4, model_layers_22_self_attn_c_attn_q_scale4, rms_norm191, model_layers_22_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape822 = R.call_tir(cls.reshape, (lv22_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape823 = R.call_tir(cls.reshape1, (reshape822,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv1038 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape823), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape824 = R.call_tir(cls.reshape2, (lv1038,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape825 = R.call_tir(cls.reshape3, (reshape824,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv44_1 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_22_self_attn_o_proj_q_weight4, model_layers_22_self_attn_o_proj_q_scale4, reshape825), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv88_1 = R.call_tir(cls.fuse_add_norm_decode, (lv44_1, lv87_1, model_layers_22_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv89: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88_1[1]
            rms_norm192: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv88_1[0]
            reshape826 = R.call_tir(cls.reshape4, (rms_norm192,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv199 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape826, model_layers_22_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv200 = R.call_tir(cls.fused_softmax_cast1, (lv199,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv1040 = R.call_tir(cls.top4_softmax, (lv200,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_094: R.Tensor((batch_size, 4), dtype="float16") = lv1040[0]
            top4_softmax_194: R.Tensor((batch_size, 4), dtype="int32") = lv1040[1]
            lv201 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_194,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape827 = R.call_tir(cls.reshape5, (lv201,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv44_2: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape827, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv45_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv44_2,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum70: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv45_1, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv1042 = R.call_tir(cls.get_indices, (cumsum70, top4_softmax_194), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_070: R.Tensor((batch_size * 4,), dtype="int32") = lv1042[0]
            get_indices_170: R.Tensor((batch_size * 4,), dtype="int32") = lv1042[1]
            lv1043 = R.call_tir(cls.get_expert_instance_indptr, (cumsum70,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take72 = R.call_tir(cls.take, (reshape826, get_indices_170), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1044 = R.call_tir(cls.dequantize_group_gemm, (take72, model_layers_22_mlp_moe_gate_up_proj_q_weight4, model_layers_22_mlp_moe_gate_up_proj_q_scale4, lv1043), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv202 = R.call_tir(cls.fused_split_silu_multiply, (lv1044,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv1045 = R.call_tir(cls.dequantize_group_gemm1, (lv202, model_layers_22_mlp_moe_down_proj_q_weight4, model_layers_22_mlp_moe_down_proj_q_scale4, lv1043), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1046 = R.call_tir(cls.scatter_output, (lv1045, get_indices_070), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape828 = R.call_tir(cls.reshape6, (top4_softmax_094,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape829 = R.call_tir(cls.reshape7, (lv1046,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv203 = R.call_tir(cls.fused_multiply1_sum, (reshape829, reshape828), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv45_2 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_22_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_22_mlp_shared_expert_gate_up_proj_q_scale4, reshape826), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv204 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv45_2,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv205 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape826, model_layers_22_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv22_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_22_mlp_shared_expert_down_proj_q_weight4, model_layers_22_mlp_shared_expert_down_proj_q_scale4, lv204, lv205, lv203), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape830 = R.call_tir(cls.reshape8, (lv22_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv90 = R.call_tir(cls.fuse_add_norm_decode, (reshape830, lv89, model_layers_23_input_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv91_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90[1]
            rms_norm193: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv90[0]
            lv23_4 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul_add, (model_layers_23_self_attn_c_attn_q_weight4, model_layers_23_self_attn_c_attn_q_scale4, rms_norm193, model_layers_23_self_attn_c_attn_bias4), out_sinfo=R.Tensor((batch_size, 1, 6144), dtype="float16"))
            reshape831 = R.call_tir(cls.reshape, (lv23_4,), out_sinfo=R.Tensor((batch_size, 1, 48, 128), dtype="float16"))
            reshape832 = R.call_tir(cls.reshape1, (reshape831,), out_sinfo=R.Tensor((batch_size, 48, 128), dtype="float16"))
            lv1050 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape832), out_sinfo=R.Tensor((batch_size, 16, 128), dtype="float16"))
            reshape833 = R.call_tir(cls.reshape2, (lv1050,), out_sinfo=R.Tensor((batch_size, 1, 16, 128), dtype="float16"))
            reshape834 = R.call_tir(cls.reshape3, (reshape833,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv46_2 = R.call_tir(cls.fused_dequantize2_NT_matmul1, (model_layers_23_self_attn_o_proj_q_weight4, model_layers_23_self_attn_o_proj_q_scale4, reshape834), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv92_1 = R.call_tir(cls.fuse_add_norm_decode, (lv46_2, lv91_1, model_layers_23_post_attention_layernorm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            lv93_1: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92_1[1]
            rms_norm194: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv92_1[0]
            reshape835 = R.call_tir(cls.reshape4, (rms_norm194,), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv208 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape835, model_layers_23_mlp_gate_weight4), out_sinfo=R.Tensor((batch_size, 60), dtype="float32"))
            lv209 = R.call_tir(cls.fused_softmax_cast1, (lv208,), out_sinfo=R.Tensor((batch_size, 60), dtype="float16"))
            lv1052 = R.call_tir(cls.top4_softmax, (lv209,), out_sinfo=[R.Tensor((batch_size, 4), dtype="float16"), R.Tensor((batch_size, 4), dtype="int32")])
            top4_softmax_095: R.Tensor((batch_size, 4), dtype="float16") = lv1052[0]
            top4_softmax_195: R.Tensor((batch_size, 4), dtype="int32") = lv1052[1]
            lv210 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_195,), out_sinfo=R.Tensor((60, batch_size), dtype="int32"))
            reshape836 = R.call_tir(cls.reshape5, (lv210,), out_sinfo=R.Tensor((batch_size * 60,), dtype="int32"))
            lv46_3: R.Tensor((1, batch_size * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape836, R.shape([1, batch_size * 60]), sinfo_args=(R.Tensor((1, batch_size * 60), dtype="int32"),))
            lv47_2 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv46_3,), out_sinfo=R.Tensor((1, batch_size * 60), dtype="int32"))
            cumsum71: R.Tensor((batch_size * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv47_2, R.shape([batch_size * 60]), sinfo_args=(R.Tensor((batch_size * 60,), dtype="int32"),))
            lv1054 = R.call_tir(cls.get_indices, (cumsum71, top4_softmax_195), out_sinfo=[R.Tensor((batch_size * 4,), dtype="int32"), R.Tensor((batch_size * 4,), dtype="int32")])
            get_indices_071: R.Tensor((batch_size * 4,), dtype="int32") = lv1054[0]
            get_indices_171: R.Tensor((batch_size * 4,), dtype="int32") = lv1054[1]
            lv1055 = R.call_tir(cls.get_expert_instance_indptr, (cumsum71,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([batch_size]))
            take73 = R.call_tir(cls.take, (reshape835, get_indices_171), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1056 = R.call_tir(cls.dequantize_group_gemm, (take73, model_layers_23_mlp_moe_gate_up_proj_q_weight4, model_layers_23_mlp_moe_gate_up_proj_q_scale4, lv1055), out_sinfo=R.Tensor((batch_size * 4, 2816), dtype="float16"))
            lv211 = R.call_tir(cls.fused_split_silu_multiply, (lv1056,), out_sinfo=R.Tensor((batch_size * 4, 1408), dtype="float16"), tir_vars=R.shape([batch_size]))
            lv1057 = R.call_tir(cls.dequantize_group_gemm1, (lv211, model_layers_23_mlp_moe_down_proj_q_weight4, model_layers_23_mlp_moe_down_proj_q_scale4, lv1055), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            lv1058 = R.call_tir(cls.scatter_output, (lv1057, get_indices_071), out_sinfo=R.Tensor((batch_size * 4, 2048), dtype="float16"))
            reshape837 = R.call_tir(cls.reshape6, (top4_softmax_095,), out_sinfo=R.Tensor((batch_size, 4, 1), dtype="float16"))
            reshape838 = R.call_tir(cls.reshape7, (lv1058,), out_sinfo=R.Tensor((batch_size, 4, 2048), dtype="float16"))
            lv212 = R.call_tir(cls.fused_multiply1_sum, (reshape838, reshape837), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            lv47_3 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_23_mlp_shared_expert_gate_up_proj_q_weight4, model_layers_23_mlp_shared_expert_gate_up_proj_q_scale4, reshape835), out_sinfo=R.Tensor((batch_size, 11264), dtype="float16"))
            lv213 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv47_3,), out_sinfo=R.Tensor((batch_size, 5632), dtype="float16"))
            lv214 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape835, model_layers_23_mlp_shared_expert_gate_weight4), out_sinfo=R.Tensor((batch_size, 1), dtype="float16"))
            lv23_5 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_23_mlp_shared_expert_down_proj_q_weight4, model_layers_23_mlp_shared_expert_down_proj_q_scale4, lv213, lv214, lv212), out_sinfo=R.Tensor((batch_size, 2048), dtype="float16"))
            reshape839 = R.call_tir(cls.reshape8, (lv23_5,), out_sinfo=R.Tensor((batch_size, 1, 2048), dtype="float16"))
            lv94_1 = R.call_tir(cls.fuse_add_norm_decode, (reshape839, lv93_1, model_norm_weight4), out_sinfo=[R.Tensor((batch_size, 1, 2048), dtype="float16"), R.Tensor((batch_size, 1, 2048), dtype="float16")])
            rms_norm195: R.Tensor((batch_size, 1, 2048), dtype="float16") = lv94_1[0]
            lv48_2 = R.call_tir(cls.fused_dequantize_fused_NT_matmul6_cast2, (lm_head_q_weight4, lm_head_q_scale4, rms_norm195), out_sinfo=R.Tensor((batch_size, 1, 151936), dtype="float32"))
            gv4: R.Tuple(R.Tensor((batch_size, 1, 151936), dtype="float32"), R.Object) = lv48_2, paged_kv_cache
            R.output(gv4)
        return gv4

    @R.function
    def batch_prefill(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), logit_positions: R.Tensor(("batch_size",), dtype="int32"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), 
dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), 
R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), 
dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), 
R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), 
dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), 
R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), 
dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), 
R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, "batch_size", 151936), dtype="float32"), R.Object):
        batch_size = T.int64()
        seq_len = T.int64()
        R.func_attr({"num_input": 3, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 32768, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_layers_0_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[10]
            model_layers_0_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[11]
            model_layers_0_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[12]
            model_layers_0_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[13]
            model_layers_0_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[14]
            model_layers_0_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[15]
            model_layers_0_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[16]
            model_layers_0_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[17]
            model_layers_0_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[18]
            model_layers_1_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[19]
            model_layers_1_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[20]
            model_layers_1_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[21]
            model_layers_1_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[22]
            model_layers_1_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[23]
            model_layers_1_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[24]
            model_layers_1_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[25]
            model_layers_1_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[26]
            model_layers_1_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[27]
            model_layers_1_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[28]
            model_layers_1_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[29]
            model_layers_1_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[30]
            model_layers_1_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[31]
            model_layers_1_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[32]
            model_layers_1_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[33]
            model_layers_1_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_1_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[35]
            model_layers_2_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[36]
            model_layers_2_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[37]
            model_layers_2_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[38]
            model_layers_2_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[39]
            model_layers_2_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[40]
            model_layers_2_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[41]
            model_layers_2_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[42]
            model_layers_2_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[43]
            model_layers_2_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[44]
            model_layers_2_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[45]
            model_layers_2_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[46]
            model_layers_2_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[47]
            model_layers_2_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[48]
            model_layers_2_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[49]
            model_layers_2_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[50]
            model_layers_2_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[51]
            model_layers_2_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[52]
            model_layers_3_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[53]
            model_layers_3_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[54]
            model_layers_3_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[55]
            model_layers_3_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[56]
            model_layers_3_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[57]
            model_layers_3_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[58]
            model_layers_3_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[59]
            model_layers_3_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[60]
            model_layers_3_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[61]
            model_layers_3_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[62]
            model_layers_3_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[63]
            model_layers_3_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[64]
            model_layers_3_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[65]
            model_layers_3_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[66]
            model_layers_3_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[67]
            model_layers_3_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[68]
            model_layers_3_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[69]
            model_layers_4_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[70]
            model_layers_4_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[71]
            model_layers_4_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[72]
            model_layers_4_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[73]
            model_layers_4_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[74]
            model_layers_4_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[75]
            model_layers_4_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[76]
            model_layers_4_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[77]
            model_layers_4_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[78]
            model_layers_4_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[79]
            model_layers_4_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[80]
            model_layers_4_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[81]
            model_layers_4_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[82]
            model_layers_4_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[83]
            model_layers_4_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[84]
            model_layers_4_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[85]
            model_layers_4_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[86]
            model_layers_5_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[87]
            model_layers_5_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[88]
            model_layers_5_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[89]
            model_layers_5_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[90]
            model_layers_5_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[91]
            model_layers_5_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[92]
            model_layers_5_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[93]
            model_layers_5_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[94]
            model_layers_5_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[95]
            model_layers_5_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[96]
            model_layers_5_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[97]
            model_layers_5_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[98]
            model_layers_5_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[99]
            model_layers_5_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[100]
            model_layers_5_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[101]
            model_layers_5_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[102]
            model_layers_5_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[103]
            model_layers_6_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[104]
            model_layers_6_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[105]
            model_layers_6_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[106]
            model_layers_6_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[107]
            model_layers_6_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[108]
            model_layers_6_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[109]
            model_layers_6_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[110]
            model_layers_6_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[111]
            model_layers_6_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[112]
            model_layers_6_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[113]
            model_layers_6_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[114]
            model_layers_6_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[115]
            model_layers_6_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[116]
            model_layers_6_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[117]
            model_layers_6_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[118]
            model_layers_6_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[119]
            model_layers_6_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[120]
            model_layers_7_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[121]
            model_layers_7_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[122]
            model_layers_7_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[123]
            model_layers_7_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[124]
            model_layers_7_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[125]
            model_layers_7_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[126]
            model_layers_7_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[127]
            model_layers_7_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[128]
            model_layers_7_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[129]
            model_layers_7_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[130]
            model_layers_7_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[131]
            model_layers_7_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[132]
            model_layers_7_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[133]
            model_layers_7_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[134]
            model_layers_7_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[135]
            model_layers_7_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[136]
            model_layers_7_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[137]
            model_layers_8_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[138]
            model_layers_8_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[139]
            model_layers_8_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[140]
            model_layers_8_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[141]
            model_layers_8_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[142]
            model_layers_8_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[143]
            model_layers_8_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[144]
            model_layers_8_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[145]
            model_layers_8_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[146]
            model_layers_8_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[147]
            model_layers_8_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[148]
            model_layers_8_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[149]
            model_layers_8_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[150]
            model_layers_8_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[151]
            model_layers_8_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[152]
            model_layers_8_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[153]
            model_layers_8_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_9_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[155]
            model_layers_9_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[156]
            model_layers_9_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[157]
            model_layers_9_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[158]
            model_layers_9_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[159]
            model_layers_9_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[160]
            model_layers_9_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[161]
            model_layers_9_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[162]
            model_layers_9_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[163]
            model_layers_9_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[164]
            model_layers_9_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[165]
            model_layers_9_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[166]
            model_layers_9_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[167]
            model_layers_9_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[168]
            model_layers_9_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[169]
            model_layers_9_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[170]
            model_layers_9_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[171]
            model_layers_10_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[172]
            model_layers_10_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[173]
            model_layers_10_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[174]
            model_layers_10_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[175]
            model_layers_10_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[176]
            model_layers_10_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[177]
            model_layers_10_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[178]
            model_layers_10_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[179]
            model_layers_10_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[180]
            model_layers_10_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[181]
            model_layers_10_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[182]
            model_layers_10_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[183]
            model_layers_10_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[184]
            model_layers_10_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[185]
            model_layers_10_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[186]
            model_layers_10_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_10_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_11_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[189]
            model_layers_11_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[190]
            model_layers_11_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[191]
            model_layers_11_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_11_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_11_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[194]
            model_layers_11_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[195]
            model_layers_11_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[196]
            model_layers_11_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[197]
            model_layers_11_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[198]
            model_layers_11_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[199]
            model_layers_11_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[200]
            model_layers_11_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[201]
            model_layers_11_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[202]
            model_layers_11_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[203]
            model_layers_11_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[204]
            model_layers_11_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[205]
            model_layers_12_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[206]
            model_layers_12_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[207]
            model_layers_12_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[208]
            model_layers_12_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[209]
            model_layers_12_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[210]
            model_layers_12_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[211]
            model_layers_12_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[212]
            model_layers_12_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[213]
            model_layers_12_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[214]
            model_layers_12_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[215]
            model_layers_12_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[216]
            model_layers_12_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[217]
            model_layers_12_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[218]
            model_layers_12_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[219]
            model_layers_12_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[220]
            model_layers_12_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_12_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[222]
            model_layers_13_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[223]
            model_layers_13_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[224]
            model_layers_13_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[225]
            model_layers_13_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[226]
            model_layers_13_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[227]
            model_layers_13_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[228]
            model_layers_13_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[229]
            model_layers_13_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[230]
            model_layers_13_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[231]
            model_layers_13_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[232]
            model_layers_13_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[233]
            model_layers_13_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[234]
            model_layers_13_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[235]
            model_layers_13_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[236]
            model_layers_13_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[237]
            model_layers_13_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[238]
            model_layers_13_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[239]
            model_layers_14_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[240]
            model_layers_14_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[241]
            model_layers_14_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[242]
            model_layers_14_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[243]
            model_layers_14_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[244]
            model_layers_14_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[245]
            model_layers_14_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[246]
            model_layers_14_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[247]
            model_layers_14_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[248]
            model_layers_14_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[249]
            model_layers_14_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[250]
            model_layers_14_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[251]
            model_layers_14_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[252]
            model_layers_14_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[253]
            model_layers_14_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[254]
            model_layers_14_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[255]
            model_layers_14_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[256]
            model_layers_15_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[257]
            model_layers_15_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[258]
            model_layers_15_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[259]
            model_layers_15_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[260]
            model_layers_15_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[261]
            model_layers_15_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[262]
            model_layers_15_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[263]
            model_layers_15_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[264]
            model_layers_15_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[265]
            model_layers_15_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[266]
            model_layers_15_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[267]
            model_layers_15_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[268]
            model_layers_15_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[269]
            model_layers_15_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[270]
            model_layers_15_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[271]
            model_layers_15_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[272]
            model_layers_15_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[273]
            model_layers_16_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[274]
            model_layers_16_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[275]
            model_layers_16_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[276]
            model_layers_16_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[277]
            model_layers_16_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[278]
            model_layers_16_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[279]
            model_layers_16_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[280]
            model_layers_16_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[281]
            model_layers_16_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[282]
            model_layers_16_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[283]
            model_layers_16_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[284]
            model_layers_16_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[285]
            model_layers_16_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[286]
            model_layers_16_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[287]
            model_layers_16_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[288]
            model_layers_16_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[289]
            model_layers_16_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[290]
            model_layers_17_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[291]
            model_layers_17_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[292]
            model_layers_17_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[293]
            model_layers_17_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[294]
            model_layers_17_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[295]
            model_layers_17_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[296]
            model_layers_17_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[297]
            model_layers_17_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[298]
            model_layers_17_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[299]
            model_layers_17_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[300]
            model_layers_17_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[301]
            model_layers_17_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[302]
            model_layers_17_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[303]
            model_layers_17_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[304]
            model_layers_17_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[305]
            model_layers_17_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[306]
            model_layers_17_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[307]
            model_layers_18_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[308]
            model_layers_18_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[309]
            model_layers_18_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[310]
            model_layers_18_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[311]
            model_layers_18_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[312]
            model_layers_18_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[313]
            model_layers_18_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[314]
            model_layers_18_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[315]
            model_layers_18_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[316]
            model_layers_18_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[317]
            model_layers_18_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[318]
            model_layers_18_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[319]
            model_layers_18_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[320]
            model_layers_18_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[321]
            model_layers_18_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[322]
            model_layers_18_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[323]
            model_layers_18_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[324]
            model_layers_19_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[325]
            model_layers_19_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[326]
            model_layers_19_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[327]
            model_layers_19_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[328]
            model_layers_19_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[329]
            model_layers_19_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[330]
            model_layers_19_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[331]
            model_layers_19_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[332]
            model_layers_19_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[333]
            model_layers_19_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[334]
            model_layers_19_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[335]
            model_layers_19_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[336]
            model_layers_19_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[337]
            model_layers_19_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[338]
            model_layers_19_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[339]
            model_layers_19_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[340]
            model_layers_19_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_20_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[342]
            model_layers_20_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[343]
            model_layers_20_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[344]
            model_layers_20_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[345]
            model_layers_20_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[346]
            model_layers_20_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[347]
            model_layers_20_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[348]
            model_layers_20_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[349]
            model_layers_20_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[350]
            model_layers_20_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[351]
            model_layers_20_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[352]
            model_layers_20_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[353]
            model_layers_20_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[354]
            model_layers_20_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[355]
            model_layers_20_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[356]
            model_layers_20_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[357]
            model_layers_20_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[358]
            model_layers_21_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[359]
            model_layers_21_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[360]
            model_layers_21_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[361]
            model_layers_21_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[362]
            model_layers_21_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[363]
            model_layers_21_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[364]
            model_layers_21_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[365]
            model_layers_21_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[366]
            model_layers_21_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[367]
            model_layers_21_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[368]
            model_layers_21_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[369]
            model_layers_21_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[370]
            model_layers_21_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[371]
            model_layers_21_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[372]
            model_layers_21_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[373]
            model_layers_21_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_21_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_22_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[376]
            model_layers_22_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[377]
            model_layers_22_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[378]
            model_layers_22_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_22_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_22_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[381]
            model_layers_22_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[382]
            model_layers_22_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[383]
            model_layers_22_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[384]
            model_layers_22_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[385]
            model_layers_22_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[386]
            model_layers_22_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[387]
            model_layers_22_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[388]
            model_layers_22_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[389]
            model_layers_22_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[390]
            model_layers_22_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[391]
            model_layers_22_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[392]
            model_layers_23_self_attn_c_attn_q_weight3: R.Tensor((6144, 256), dtype="uint32") = packed_params[393]
            model_layers_23_self_attn_c_attn_q_scale3: R.Tensor((6144, 64), dtype="float16") = packed_params[394]
            model_layers_23_self_attn_c_attn_bias3: R.Tensor((6144,), dtype="float16") = packed_params[395]
            model_layers_23_self_attn_o_proj_q_weight3: R.Tensor((2048, 256), dtype="uint32") = packed_params[396]
            model_layers_23_self_attn_o_proj_q_scale3: R.Tensor((2048, 64), dtype="float16") = packed_params[397]
            model_layers_23_mlp_shared_expert_gate_up_proj_q_weight3: R.Tensor((11264, 256), dtype="uint32") = packed_params[398]
            model_layers_23_mlp_shared_expert_gate_up_proj_q_scale3: R.Tensor((11264, 64), dtype="float16") = packed_params[399]
            model_layers_23_mlp_shared_expert_down_proj_q_weight3: R.Tensor((2048, 704), dtype="uint32") = packed_params[400]
            model_layers_23_mlp_shared_expert_down_proj_q_scale3: R.Tensor((2048, 176), dtype="float16") = packed_params[401]
            model_layers_23_mlp_shared_expert_gate_weight3: R.Tensor((1, 2048), dtype="float16") = packed_params[402]
            model_layers_23_mlp_gate_weight3: R.Tensor((60, 2048), dtype="float16") = packed_params[403]
            model_layers_23_mlp_moe_gate_up_proj_q_weight3: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[404]
            model_layers_23_mlp_moe_gate_up_proj_q_scale3: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[405]
            model_layers_23_mlp_moe_down_proj_q_weight3: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[406]
            model_layers_23_mlp_moe_down_proj_q_scale3: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[407]
            model_layers_23_input_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[408]
            model_layers_23_post_attention_layernorm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[409]
            model_norm_weight3: R.Tensor((2048,), dtype="float16") = packed_params[410]
            lm_head_q_weight3: R.Tensor((151936, 256), dtype="uint32") = packed_params[411]
            lm_head_q_scale3: R.Tensor((151936, 64), dtype="float16") = packed_params[412]
            rms_norm98 = R.call_tir(cls.rms_norm1, (input_embeds, model_layers_0_input_layernorm_weight3), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv24 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_0_self_attn_c_attn_q_weight3, model_layers_0_self_attn_c_attn_q_scale3, rms_norm98, model_layers_0_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape408 = R.call_tir(cls.reshape9, (lv24,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape409 = R.call_tir(cls.reshape10, (reshape408,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv485 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape409), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape410 = R.call_tir(cls.reshape11, (lv485,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape411 = R.call_tir(cls.reshape12, (reshape410,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv49 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_0_self_attn_o_proj_q_weight3, model_layers_0_self_attn_o_proj_q_scale3, reshape411), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv96 = R.call_tir(cls.fuse_add_norm_prefill, (lv49, input_embeds, model_layers_0_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv97: R.Tensor((1, seq_len, 2048), dtype="float16") = lv96[1]
            rms_norm99: R.Tensor((1, seq_len, 2048), dtype="float16") = lv96[0]
            reshape412 = R.call_tir(cls.reshape13, (rms_norm99,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv218 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape412, model_layers_0_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv219 = R.call_tir(cls.fused_softmax_cast1, (lv218,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv487 = R.call_tir(cls.top4_softmax, (lv219,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_048: R.Tensor((seq_len, 4), dtype="float16") = lv487[0]
            top4_softmax_148: R.Tensor((seq_len, 4), dtype="int32") = lv487[1]
            lv220 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_148,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape413 = R.call_tir(cls.reshape5, (lv220,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv48: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape413, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv49_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv48,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum24: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv49_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv489 = R.call_tir(cls.get_indices, (cumsum24, top4_softmax_148), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_024: R.Tensor((seq_len * 4,), dtype="int32") = lv489[0]
            get_indices_124: R.Tensor((seq_len * 4,), dtype="int32") = lv489[1]
            lv490 = R.call_tir(cls.get_expert_instance_indptr, (cumsum24,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take25 = R.call_tir(cls.take, (reshape412, get_indices_124), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv491 = R.call_tir(cls.dequantize_group_gemm, (take25, model_layers_0_mlp_moe_gate_up_proj_q_weight3, model_layers_0_mlp_moe_gate_up_proj_q_scale3, lv490), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv221 = R.call_tir(cls.fused_split_silu_multiply, (lv491,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv492 = R.call_tir(cls.dequantize_group_gemm1, (lv221, model_layers_0_mlp_moe_down_proj_q_weight3, model_layers_0_mlp_moe_down_proj_q_scale3, lv490), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv493 = R.call_tir(cls.scatter_output, (lv492, get_indices_024), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape414 = R.call_tir(cls.reshape6, (top4_softmax_048,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape415 = R.call_tir(cls.reshape7, (lv493,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv222 = R.call_tir(cls.fused_multiply1_sum, (reshape415, reshape414), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv50 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_0_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_0_mlp_shared_expert_gate_up_proj_q_scale3, reshape412), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv223 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv50,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv224 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape412, model_layers_0_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv24_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_0_mlp_shared_expert_down_proj_q_weight3, model_layers_0_mlp_shared_expert_down_proj_q_scale3, lv223, lv224, lv222), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape416 = R.call_tir(cls.reshape14, (lv24_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv98 = R.call_tir(cls.fuse_add_norm_prefill, (reshape416, lv97, model_layers_1_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv99: R.Tensor((1, seq_len, 2048), dtype="float16") = lv98[1]
            rms_norm100: R.Tensor((1, seq_len, 2048), dtype="float16") = lv98[0]
            lv25 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_1_self_attn_c_attn_q_weight3, model_layers_1_self_attn_c_attn_q_scale3, rms_norm100, model_layers_1_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape417 = R.call_tir(cls.reshape9, (lv25,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape418 = R.call_tir(cls.reshape10, (reshape417,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv497 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape418), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape419 = R.call_tir(cls.reshape11, (lv497,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape420 = R.call_tir(cls.reshape12, (reshape419,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv51 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_1_self_attn_o_proj_q_weight3, model_layers_1_self_attn_o_proj_q_scale3, reshape420), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv100 = R.call_tir(cls.fuse_add_norm_prefill, (lv51, lv99, model_layers_1_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv101: R.Tensor((1, seq_len, 2048), dtype="float16") = lv100[1]
            rms_norm101: R.Tensor((1, seq_len, 2048), dtype="float16") = lv100[0]
            reshape421 = R.call_tir(cls.reshape13, (rms_norm101,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv227 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape421, model_layers_1_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv228 = R.call_tir(cls.fused_softmax_cast1, (lv227,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv499 = R.call_tir(cls.top4_softmax, (lv228,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_049: R.Tensor((seq_len, 4), dtype="float16") = lv499[0]
            top4_softmax_149: R.Tensor((seq_len, 4), dtype="int32") = lv499[1]
            lv229 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_149,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape422 = R.call_tir(cls.reshape5, (lv229,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv50_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape422, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv51_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv50_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum25: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv51_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv501 = R.call_tir(cls.get_indices, (cumsum25, top4_softmax_149), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_025: R.Tensor((seq_len * 4,), dtype="int32") = lv501[0]
            get_indices_125: R.Tensor((seq_len * 4,), dtype="int32") = lv501[1]
            lv502 = R.call_tir(cls.get_expert_instance_indptr, (cumsum25,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take26 = R.call_tir(cls.take, (reshape421, get_indices_125), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv503 = R.call_tir(cls.dequantize_group_gemm, (take26, model_layers_1_mlp_moe_gate_up_proj_q_weight3, model_layers_1_mlp_moe_gate_up_proj_q_scale3, lv502), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv230 = R.call_tir(cls.fused_split_silu_multiply, (lv503,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv504 = R.call_tir(cls.dequantize_group_gemm1, (lv230, model_layers_1_mlp_moe_down_proj_q_weight3, model_layers_1_mlp_moe_down_proj_q_scale3, lv502), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv505 = R.call_tir(cls.scatter_output, (lv504, get_indices_025), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape423 = R.call_tir(cls.reshape6, (top4_softmax_049,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape424 = R.call_tir(cls.reshape7, (lv505,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv231 = R.call_tir(cls.fused_multiply1_sum, (reshape424, reshape423), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv52 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_1_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_1_mlp_shared_expert_gate_up_proj_q_scale3, reshape421), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv232 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv52,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv233 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape421, model_layers_1_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv25_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_1_mlp_shared_expert_down_proj_q_weight3, model_layers_1_mlp_shared_expert_down_proj_q_scale3, lv232, lv233, lv231), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape425 = R.call_tir(cls.reshape14, (lv25_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv102 = R.call_tir(cls.fuse_add_norm_prefill, (reshape425, lv101, model_layers_2_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv103: R.Tensor((1, seq_len, 2048), dtype="float16") = lv102[1]
            rms_norm102: R.Tensor((1, seq_len, 2048), dtype="float16") = lv102[0]
            lv26 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_2_self_attn_c_attn_q_weight3, model_layers_2_self_attn_c_attn_q_scale3, rms_norm102, model_layers_2_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape426 = R.call_tir(cls.reshape9, (lv26,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape427 = R.call_tir(cls.reshape10, (reshape426,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv509 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape427), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape428 = R.call_tir(cls.reshape11, (lv509,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape429 = R.call_tir(cls.reshape12, (reshape428,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv53 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_2_self_attn_o_proj_q_weight3, model_layers_2_self_attn_o_proj_q_scale3, reshape429), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv104 = R.call_tir(cls.fuse_add_norm_prefill, (lv53, lv103, model_layers_2_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv105: R.Tensor((1, seq_len, 2048), dtype="float16") = lv104[1]
            rms_norm103: R.Tensor((1, seq_len, 2048), dtype="float16") = lv104[0]
            reshape430 = R.call_tir(cls.reshape13, (rms_norm103,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv236 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape430, model_layers_2_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv237 = R.call_tir(cls.fused_softmax_cast1, (lv236,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv511 = R.call_tir(cls.top4_softmax, (lv237,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_050: R.Tensor((seq_len, 4), dtype="float16") = lv511[0]
            top4_softmax_150: R.Tensor((seq_len, 4), dtype="int32") = lv511[1]
            lv238 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_150,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape431 = R.call_tir(cls.reshape5, (lv238,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv52_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape431, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv53_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv52_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum26: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv53_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv513 = R.call_tir(cls.get_indices, (cumsum26, top4_softmax_150), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_026: R.Tensor((seq_len * 4,), dtype="int32") = lv513[0]
            get_indices_126: R.Tensor((seq_len * 4,), dtype="int32") = lv513[1]
            lv514 = R.call_tir(cls.get_expert_instance_indptr, (cumsum26,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take27 = R.call_tir(cls.take, (reshape430, get_indices_126), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv515 = R.call_tir(cls.dequantize_group_gemm, (take27, model_layers_2_mlp_moe_gate_up_proj_q_weight3, model_layers_2_mlp_moe_gate_up_proj_q_scale3, lv514), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv239 = R.call_tir(cls.fused_split_silu_multiply, (lv515,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv516 = R.call_tir(cls.dequantize_group_gemm1, (lv239, model_layers_2_mlp_moe_down_proj_q_weight3, model_layers_2_mlp_moe_down_proj_q_scale3, lv514), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv517 = R.call_tir(cls.scatter_output, (lv516, get_indices_026), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape432 = R.call_tir(cls.reshape6, (top4_softmax_050,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape433 = R.call_tir(cls.reshape7, (lv517,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv240 = R.call_tir(cls.fused_multiply1_sum, (reshape433, reshape432), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv54 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_2_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_2_mlp_shared_expert_gate_up_proj_q_scale3, reshape430), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv241 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv54,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv242 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape430, model_layers_2_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv26_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_2_mlp_shared_expert_down_proj_q_weight3, model_layers_2_mlp_shared_expert_down_proj_q_scale3, lv241, lv242, lv240), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape434 = R.call_tir(cls.reshape14, (lv26_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv106 = R.call_tir(cls.fuse_add_norm_prefill, (reshape434, lv105, model_layers_3_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv107: R.Tensor((1, seq_len, 2048), dtype="float16") = lv106[1]
            rms_norm104: R.Tensor((1, seq_len, 2048), dtype="float16") = lv106[0]
            lv27 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_3_self_attn_c_attn_q_weight3, model_layers_3_self_attn_c_attn_q_scale3, rms_norm104, model_layers_3_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape435 = R.call_tir(cls.reshape9, (lv27,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape436 = R.call_tir(cls.reshape10, (reshape435,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv521 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(3), R.prim_value(T.float32(1.0)), reshape436), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape437 = R.call_tir(cls.reshape11, (lv521,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape438 = R.call_tir(cls.reshape12, (reshape437,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv55 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_3_self_attn_o_proj_q_weight3, model_layers_3_self_attn_o_proj_q_scale3, reshape438), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv108 = R.call_tir(cls.fuse_add_norm_prefill, (lv55, lv107, model_layers_3_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv109: R.Tensor((1, seq_len, 2048), dtype="float16") = lv108[1]
            rms_norm105: R.Tensor((1, seq_len, 2048), dtype="float16") = lv108[0]
            reshape439 = R.call_tir(cls.reshape13, (rms_norm105,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv245 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape439, model_layers_3_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv246 = R.call_tir(cls.fused_softmax_cast1, (lv245,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv523 = R.call_tir(cls.top4_softmax, (lv246,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_051: R.Tensor((seq_len, 4), dtype="float16") = lv523[0]
            top4_softmax_151: R.Tensor((seq_len, 4), dtype="int32") = lv523[1]
            lv247 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_151,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape440 = R.call_tir(cls.reshape5, (lv247,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv54_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape440, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv55_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv54_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum27: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv55_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv525 = R.call_tir(cls.get_indices, (cumsum27, top4_softmax_151), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_027: R.Tensor((seq_len * 4,), dtype="int32") = lv525[0]
            get_indices_127: R.Tensor((seq_len * 4,), dtype="int32") = lv525[1]
            lv526 = R.call_tir(cls.get_expert_instance_indptr, (cumsum27,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take28 = R.call_tir(cls.take, (reshape439, get_indices_127), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv527 = R.call_tir(cls.dequantize_group_gemm, (take28, model_layers_3_mlp_moe_gate_up_proj_q_weight3, model_layers_3_mlp_moe_gate_up_proj_q_scale3, lv526), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv248 = R.call_tir(cls.fused_split_silu_multiply, (lv527,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv528 = R.call_tir(cls.dequantize_group_gemm1, (lv248, model_layers_3_mlp_moe_down_proj_q_weight3, model_layers_3_mlp_moe_down_proj_q_scale3, lv526), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv529 = R.call_tir(cls.scatter_output, (lv528, get_indices_027), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape441 = R.call_tir(cls.reshape6, (top4_softmax_051,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape442 = R.call_tir(cls.reshape7, (lv529,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv249 = R.call_tir(cls.fused_multiply1_sum, (reshape442, reshape441), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv56 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_3_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_3_mlp_shared_expert_gate_up_proj_q_scale3, reshape439), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv250 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv56,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv251 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape439, model_layers_3_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv27_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_3_mlp_shared_expert_down_proj_q_weight3, model_layers_3_mlp_shared_expert_down_proj_q_scale3, lv250, lv251, lv249), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape443 = R.call_tir(cls.reshape14, (lv27_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv110 = R.call_tir(cls.fuse_add_norm_prefill, (reshape443, lv109, model_layers_4_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv111: R.Tensor((1, seq_len, 2048), dtype="float16") = lv110[1]
            rms_norm106: R.Tensor((1, seq_len, 2048), dtype="float16") = lv110[0]
            lv28 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_4_self_attn_c_attn_q_weight3, model_layers_4_self_attn_c_attn_q_scale3, rms_norm106, model_layers_4_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape444 = R.call_tir(cls.reshape9, (lv28,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape445 = R.call_tir(cls.reshape10, (reshape444,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv533 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(4), R.prim_value(T.float32(1.0)), reshape445), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape446 = R.call_tir(cls.reshape11, (lv533,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape447 = R.call_tir(cls.reshape12, (reshape446,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv57 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_4_self_attn_o_proj_q_weight3, model_layers_4_self_attn_o_proj_q_scale3, reshape447), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv112 = R.call_tir(cls.fuse_add_norm_prefill, (lv57, lv111, model_layers_4_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv113: R.Tensor((1, seq_len, 2048), dtype="float16") = lv112[1]
            rms_norm107: R.Tensor((1, seq_len, 2048), dtype="float16") = lv112[0]
            reshape448 = R.call_tir(cls.reshape13, (rms_norm107,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv254 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape448, model_layers_4_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv255 = R.call_tir(cls.fused_softmax_cast1, (lv254,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv535 = R.call_tir(cls.top4_softmax, (lv255,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_052: R.Tensor((seq_len, 4), dtype="float16") = lv535[0]
            top4_softmax_152: R.Tensor((seq_len, 4), dtype="int32") = lv535[1]
            lv256 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_152,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape449 = R.call_tir(cls.reshape5, (lv256,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv56_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape449, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv57_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv56_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum28: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv57_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv537 = R.call_tir(cls.get_indices, (cumsum28, top4_softmax_152), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_028: R.Tensor((seq_len * 4,), dtype="int32") = lv537[0]
            get_indices_128: R.Tensor((seq_len * 4,), dtype="int32") = lv537[1]
            lv538 = R.call_tir(cls.get_expert_instance_indptr, (cumsum28,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take29 = R.call_tir(cls.take, (reshape448, get_indices_128), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv539 = R.call_tir(cls.dequantize_group_gemm, (take29, model_layers_4_mlp_moe_gate_up_proj_q_weight3, model_layers_4_mlp_moe_gate_up_proj_q_scale3, lv538), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv257 = R.call_tir(cls.fused_split_silu_multiply, (lv539,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv540 = R.call_tir(cls.dequantize_group_gemm1, (lv257, model_layers_4_mlp_moe_down_proj_q_weight3, model_layers_4_mlp_moe_down_proj_q_scale3, lv538), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv541 = R.call_tir(cls.scatter_output, (lv540, get_indices_028), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape450 = R.call_tir(cls.reshape6, (top4_softmax_052,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape451 = R.call_tir(cls.reshape7, (lv541,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv258 = R.call_tir(cls.fused_multiply1_sum, (reshape451, reshape450), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv58 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_4_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_4_mlp_shared_expert_gate_up_proj_q_scale3, reshape448), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv259 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv58,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv260 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape448, model_layers_4_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv28_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_4_mlp_shared_expert_down_proj_q_weight3, model_layers_4_mlp_shared_expert_down_proj_q_scale3, lv259, lv260, lv258), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape452 = R.call_tir(cls.reshape14, (lv28_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv114 = R.call_tir(cls.fuse_add_norm_prefill, (reshape452, lv113, model_layers_5_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv115: R.Tensor((1, seq_len, 2048), dtype="float16") = lv114[1]
            rms_norm108: R.Tensor((1, seq_len, 2048), dtype="float16") = lv114[0]
            lv29 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_5_self_attn_c_attn_q_weight3, model_layers_5_self_attn_c_attn_q_scale3, rms_norm108, model_layers_5_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape453 = R.call_tir(cls.reshape9, (lv29,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape454 = R.call_tir(cls.reshape10, (reshape453,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv545 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(5), R.prim_value(T.float32(1.0)), reshape454), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape455 = R.call_tir(cls.reshape11, (lv545,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape456 = R.call_tir(cls.reshape12, (reshape455,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv59 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_5_self_attn_o_proj_q_weight3, model_layers_5_self_attn_o_proj_q_scale3, reshape456), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv116 = R.call_tir(cls.fuse_add_norm_prefill, (lv59, lv115, model_layers_5_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv117: R.Tensor((1, seq_len, 2048), dtype="float16") = lv116[1]
            rms_norm109: R.Tensor((1, seq_len, 2048), dtype="float16") = lv116[0]
            reshape457 = R.call_tir(cls.reshape13, (rms_norm109,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv263 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape457, model_layers_5_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv264 = R.call_tir(cls.fused_softmax_cast1, (lv263,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv547 = R.call_tir(cls.top4_softmax, (lv264,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_053: R.Tensor((seq_len, 4), dtype="float16") = lv547[0]
            top4_softmax_153: R.Tensor((seq_len, 4), dtype="int32") = lv547[1]
            lv265 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_153,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape458 = R.call_tir(cls.reshape5, (lv265,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv58_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape458, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv59_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv58_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum29: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv59_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv549 = R.call_tir(cls.get_indices, (cumsum29, top4_softmax_153), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_029: R.Tensor((seq_len * 4,), dtype="int32") = lv549[0]
            get_indices_129: R.Tensor((seq_len * 4,), dtype="int32") = lv549[1]
            lv550 = R.call_tir(cls.get_expert_instance_indptr, (cumsum29,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take30 = R.call_tir(cls.take, (reshape457, get_indices_129), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv551 = R.call_tir(cls.dequantize_group_gemm, (take30, model_layers_5_mlp_moe_gate_up_proj_q_weight3, model_layers_5_mlp_moe_gate_up_proj_q_scale3, lv550), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv266 = R.call_tir(cls.fused_split_silu_multiply, (lv551,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv552 = R.call_tir(cls.dequantize_group_gemm1, (lv266, model_layers_5_mlp_moe_down_proj_q_weight3, model_layers_5_mlp_moe_down_proj_q_scale3, lv550), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv553 = R.call_tir(cls.scatter_output, (lv552, get_indices_029), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape459 = R.call_tir(cls.reshape6, (top4_softmax_053,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape460 = R.call_tir(cls.reshape7, (lv553,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv267 = R.call_tir(cls.fused_multiply1_sum, (reshape460, reshape459), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv60 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_5_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_5_mlp_shared_expert_gate_up_proj_q_scale3, reshape457), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv268 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv60,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv269 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape457, model_layers_5_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv29_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_5_mlp_shared_expert_down_proj_q_weight3, model_layers_5_mlp_shared_expert_down_proj_q_scale3, lv268, lv269, lv267), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape461 = R.call_tir(cls.reshape14, (lv29_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv118 = R.call_tir(cls.fuse_add_norm_prefill, (reshape461, lv117, model_layers_6_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv119: R.Tensor((1, seq_len, 2048), dtype="float16") = lv118[1]
            rms_norm110: R.Tensor((1, seq_len, 2048), dtype="float16") = lv118[0]
            lv30 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_6_self_attn_c_attn_q_weight3, model_layers_6_self_attn_c_attn_q_scale3, rms_norm110, model_layers_6_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape462 = R.call_tir(cls.reshape9, (lv30,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape463 = R.call_tir(cls.reshape10, (reshape462,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv557 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(6), R.prim_value(T.float32(1.0)), reshape463), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape464 = R.call_tir(cls.reshape11, (lv557,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape465 = R.call_tir(cls.reshape12, (reshape464,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv61 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_6_self_attn_o_proj_q_weight3, model_layers_6_self_attn_o_proj_q_scale3, reshape465), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv120 = R.call_tir(cls.fuse_add_norm_prefill, (lv61, lv119, model_layers_6_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv121: R.Tensor((1, seq_len, 2048), dtype="float16") = lv120[1]
            rms_norm111: R.Tensor((1, seq_len, 2048), dtype="float16") = lv120[0]
            reshape466 = R.call_tir(cls.reshape13, (rms_norm111,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv272 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape466, model_layers_6_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv273 = R.call_tir(cls.fused_softmax_cast1, (lv272,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv559 = R.call_tir(cls.top4_softmax, (lv273,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_054: R.Tensor((seq_len, 4), dtype="float16") = lv559[0]
            top4_softmax_154: R.Tensor((seq_len, 4), dtype="int32") = lv559[1]
            lv274 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_154,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape467 = R.call_tir(cls.reshape5, (lv274,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv60_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape467, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv61_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv60_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum30: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv61_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv561 = R.call_tir(cls.get_indices, (cumsum30, top4_softmax_154), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_030: R.Tensor((seq_len * 4,), dtype="int32") = lv561[0]
            get_indices_130: R.Tensor((seq_len * 4,), dtype="int32") = lv561[1]
            lv562 = R.call_tir(cls.get_expert_instance_indptr, (cumsum30,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take31 = R.call_tir(cls.take, (reshape466, get_indices_130), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv563 = R.call_tir(cls.dequantize_group_gemm, (take31, model_layers_6_mlp_moe_gate_up_proj_q_weight3, model_layers_6_mlp_moe_gate_up_proj_q_scale3, lv562), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv275 = R.call_tir(cls.fused_split_silu_multiply, (lv563,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv564 = R.call_tir(cls.dequantize_group_gemm1, (lv275, model_layers_6_mlp_moe_down_proj_q_weight3, model_layers_6_mlp_moe_down_proj_q_scale3, lv562), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv565 = R.call_tir(cls.scatter_output, (lv564, get_indices_030), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape468 = R.call_tir(cls.reshape6, (top4_softmax_054,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape469 = R.call_tir(cls.reshape7, (lv565,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv276 = R.call_tir(cls.fused_multiply1_sum, (reshape469, reshape468), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv62 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_6_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_6_mlp_shared_expert_gate_up_proj_q_scale3, reshape466), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv277 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv62,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv278 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape466, model_layers_6_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv30_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_6_mlp_shared_expert_down_proj_q_weight3, model_layers_6_mlp_shared_expert_down_proj_q_scale3, lv277, lv278, lv276), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape470 = R.call_tir(cls.reshape14, (lv30_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv122 = R.call_tir(cls.fuse_add_norm_prefill, (reshape470, lv121, model_layers_7_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv123: R.Tensor((1, seq_len, 2048), dtype="float16") = lv122[1]
            rms_norm112: R.Tensor((1, seq_len, 2048), dtype="float16") = lv122[0]
            lv31 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_7_self_attn_c_attn_q_weight3, model_layers_7_self_attn_c_attn_q_scale3, rms_norm112, model_layers_7_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape471 = R.call_tir(cls.reshape9, (lv31,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape472 = R.call_tir(cls.reshape10, (reshape471,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv569 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(7), R.prim_value(T.float32(1.0)), reshape472), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape473 = R.call_tir(cls.reshape11, (lv569,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape474 = R.call_tir(cls.reshape12, (reshape473,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv63 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_7_self_attn_o_proj_q_weight3, model_layers_7_self_attn_o_proj_q_scale3, reshape474), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv124 = R.call_tir(cls.fuse_add_norm_prefill, (lv63, lv123, model_layers_7_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv125: R.Tensor((1, seq_len, 2048), dtype="float16") = lv124[1]
            rms_norm113: R.Tensor((1, seq_len, 2048), dtype="float16") = lv124[0]
            reshape475 = R.call_tir(cls.reshape13, (rms_norm113,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv281 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape475, model_layers_7_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv282 = R.call_tir(cls.fused_softmax_cast1, (lv281,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv571 = R.call_tir(cls.top4_softmax, (lv282,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_055: R.Tensor((seq_len, 4), dtype="float16") = lv571[0]
            top4_softmax_155: R.Tensor((seq_len, 4), dtype="int32") = lv571[1]
            lv283 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_155,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape476 = R.call_tir(cls.reshape5, (lv283,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv62_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape476, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv63_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv62_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum31: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv63_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv573 = R.call_tir(cls.get_indices, (cumsum31, top4_softmax_155), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_031: R.Tensor((seq_len * 4,), dtype="int32") = lv573[0]
            get_indices_131: R.Tensor((seq_len * 4,), dtype="int32") = lv573[1]
            lv574 = R.call_tir(cls.get_expert_instance_indptr, (cumsum31,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take32 = R.call_tir(cls.take, (reshape475, get_indices_131), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv575 = R.call_tir(cls.dequantize_group_gemm, (take32, model_layers_7_mlp_moe_gate_up_proj_q_weight3, model_layers_7_mlp_moe_gate_up_proj_q_scale3, lv574), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv284 = R.call_tir(cls.fused_split_silu_multiply, (lv575,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv576 = R.call_tir(cls.dequantize_group_gemm1, (lv284, model_layers_7_mlp_moe_down_proj_q_weight3, model_layers_7_mlp_moe_down_proj_q_scale3, lv574), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv577 = R.call_tir(cls.scatter_output, (lv576, get_indices_031), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape477 = R.call_tir(cls.reshape6, (top4_softmax_055,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape478 = R.call_tir(cls.reshape7, (lv577,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv285 = R.call_tir(cls.fused_multiply1_sum, (reshape478, reshape477), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv64 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_7_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_7_mlp_shared_expert_gate_up_proj_q_scale3, reshape475), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv286 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv64,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv287 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape475, model_layers_7_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv31_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_7_mlp_shared_expert_down_proj_q_weight3, model_layers_7_mlp_shared_expert_down_proj_q_scale3, lv286, lv287, lv285), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape479 = R.call_tir(cls.reshape14, (lv31_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv126 = R.call_tir(cls.fuse_add_norm_prefill, (reshape479, lv125, model_layers_8_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv127: R.Tensor((1, seq_len, 2048), dtype="float16") = lv126[1]
            rms_norm114: R.Tensor((1, seq_len, 2048), dtype="float16") = lv126[0]
            lv32 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_8_self_attn_c_attn_q_weight3, model_layers_8_self_attn_c_attn_q_scale3, rms_norm114, model_layers_8_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape480 = R.call_tir(cls.reshape9, (lv32,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape481 = R.call_tir(cls.reshape10, (reshape480,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv581 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(8), R.prim_value(T.float32(1.0)), reshape481), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape482 = R.call_tir(cls.reshape11, (lv581,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape483 = R.call_tir(cls.reshape12, (reshape482,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv65 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_8_self_attn_o_proj_q_weight3, model_layers_8_self_attn_o_proj_q_scale3, reshape483), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv128 = R.call_tir(cls.fuse_add_norm_prefill, (lv65, lv127, model_layers_8_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv129: R.Tensor((1, seq_len, 2048), dtype="float16") = lv128[1]
            rms_norm115: R.Tensor((1, seq_len, 2048), dtype="float16") = lv128[0]
            reshape484 = R.call_tir(cls.reshape13, (rms_norm115,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv290 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape484, model_layers_8_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv291 = R.call_tir(cls.fused_softmax_cast1, (lv290,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv583 = R.call_tir(cls.top4_softmax, (lv291,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_056: R.Tensor((seq_len, 4), dtype="float16") = lv583[0]
            top4_softmax_156: R.Tensor((seq_len, 4), dtype="int32") = lv583[1]
            lv292 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_156,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape485 = R.call_tir(cls.reshape5, (lv292,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv64_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape485, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv65_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv64_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum32: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv65_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv585 = R.call_tir(cls.get_indices, (cumsum32, top4_softmax_156), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_032: R.Tensor((seq_len * 4,), dtype="int32") = lv585[0]
            get_indices_132: R.Tensor((seq_len * 4,), dtype="int32") = lv585[1]
            lv586 = R.call_tir(cls.get_expert_instance_indptr, (cumsum32,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take33 = R.call_tir(cls.take, (reshape484, get_indices_132), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv587 = R.call_tir(cls.dequantize_group_gemm, (take33, model_layers_8_mlp_moe_gate_up_proj_q_weight3, model_layers_8_mlp_moe_gate_up_proj_q_scale3, lv586), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv293 = R.call_tir(cls.fused_split_silu_multiply, (lv587,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv588 = R.call_tir(cls.dequantize_group_gemm1, (lv293, model_layers_8_mlp_moe_down_proj_q_weight3, model_layers_8_mlp_moe_down_proj_q_scale3, lv586), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv589 = R.call_tir(cls.scatter_output, (lv588, get_indices_032), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape486 = R.call_tir(cls.reshape6, (top4_softmax_056,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape487 = R.call_tir(cls.reshape7, (lv589,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv294 = R.call_tir(cls.fused_multiply1_sum, (reshape487, reshape486), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv66 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_8_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_8_mlp_shared_expert_gate_up_proj_q_scale3, reshape484), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv295 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv66,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv296 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape484, model_layers_8_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv32_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_8_mlp_shared_expert_down_proj_q_weight3, model_layers_8_mlp_shared_expert_down_proj_q_scale3, lv295, lv296, lv294), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape488 = R.call_tir(cls.reshape14, (lv32_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv130 = R.call_tir(cls.fuse_add_norm_prefill, (reshape488, lv129, model_layers_9_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv131: R.Tensor((1, seq_len, 2048), dtype="float16") = lv130[1]
            rms_norm116: R.Tensor((1, seq_len, 2048), dtype="float16") = lv130[0]
            lv33 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_9_self_attn_c_attn_q_weight3, model_layers_9_self_attn_c_attn_q_scale3, rms_norm116, model_layers_9_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape489 = R.call_tir(cls.reshape9, (lv33,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape490 = R.call_tir(cls.reshape10, (reshape489,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv593 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(9), R.prim_value(T.float32(1.0)), reshape490), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape491 = R.call_tir(cls.reshape11, (lv593,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape492 = R.call_tir(cls.reshape12, (reshape491,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv67 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_9_self_attn_o_proj_q_weight3, model_layers_9_self_attn_o_proj_q_scale3, reshape492), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv132 = R.call_tir(cls.fuse_add_norm_prefill, (lv67, lv131, model_layers_9_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv133: R.Tensor((1, seq_len, 2048), dtype="float16") = lv132[1]
            rms_norm117: R.Tensor((1, seq_len, 2048), dtype="float16") = lv132[0]
            reshape493 = R.call_tir(cls.reshape13, (rms_norm117,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv299 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape493, model_layers_9_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv300 = R.call_tir(cls.fused_softmax_cast1, (lv299,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv595 = R.call_tir(cls.top4_softmax, (lv300,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_057: R.Tensor((seq_len, 4), dtype="float16") = lv595[0]
            top4_softmax_157: R.Tensor((seq_len, 4), dtype="int32") = lv595[1]
            lv301 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_157,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape494 = R.call_tir(cls.reshape5, (lv301,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv66_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape494, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv67_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv66_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum33: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv67_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv597 = R.call_tir(cls.get_indices, (cumsum33, top4_softmax_157), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_033: R.Tensor((seq_len * 4,), dtype="int32") = lv597[0]
            get_indices_133: R.Tensor((seq_len * 4,), dtype="int32") = lv597[1]
            lv598 = R.call_tir(cls.get_expert_instance_indptr, (cumsum33,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take34 = R.call_tir(cls.take, (reshape493, get_indices_133), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv599 = R.call_tir(cls.dequantize_group_gemm, (take34, model_layers_9_mlp_moe_gate_up_proj_q_weight3, model_layers_9_mlp_moe_gate_up_proj_q_scale3, lv598), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv302 = R.call_tir(cls.fused_split_silu_multiply, (lv599,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv600 = R.call_tir(cls.dequantize_group_gemm1, (lv302, model_layers_9_mlp_moe_down_proj_q_weight3, model_layers_9_mlp_moe_down_proj_q_scale3, lv598), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv601 = R.call_tir(cls.scatter_output, (lv600, get_indices_033), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape495 = R.call_tir(cls.reshape6, (top4_softmax_057,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape496 = R.call_tir(cls.reshape7, (lv601,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv303 = R.call_tir(cls.fused_multiply1_sum, (reshape496, reshape495), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv68 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_9_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_9_mlp_shared_expert_gate_up_proj_q_scale3, reshape493), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv304 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv68,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv305 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape493, model_layers_9_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv33_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_9_mlp_shared_expert_down_proj_q_weight3, model_layers_9_mlp_shared_expert_down_proj_q_scale3, lv304, lv305, lv303), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape497 = R.call_tir(cls.reshape14, (lv33_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv134 = R.call_tir(cls.fuse_add_norm_prefill, (reshape497, lv133, model_layers_10_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv135: R.Tensor((1, seq_len, 2048), dtype="float16") = lv134[1]
            rms_norm118: R.Tensor((1, seq_len, 2048), dtype="float16") = lv134[0]
            lv34 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_10_self_attn_c_attn_q_weight3, model_layers_10_self_attn_c_attn_q_scale3, rms_norm118, model_layers_10_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape498 = R.call_tir(cls.reshape9, (lv34,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape499 = R.call_tir(cls.reshape10, (reshape498,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv605 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(10), R.prim_value(T.float32(1.0)), reshape499), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape500 = R.call_tir(cls.reshape11, (lv605,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape501 = R.call_tir(cls.reshape12, (reshape500,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv69 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_10_self_attn_o_proj_q_weight3, model_layers_10_self_attn_o_proj_q_scale3, reshape501), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv136 = R.call_tir(cls.fuse_add_norm_prefill, (lv69, lv135, model_layers_10_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv137: R.Tensor((1, seq_len, 2048), dtype="float16") = lv136[1]
            rms_norm119: R.Tensor((1, seq_len, 2048), dtype="float16") = lv136[0]
            reshape502 = R.call_tir(cls.reshape13, (rms_norm119,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv308 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape502, model_layers_10_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv309 = R.call_tir(cls.fused_softmax_cast1, (lv308,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv607 = R.call_tir(cls.top4_softmax, (lv309,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_058: R.Tensor((seq_len, 4), dtype="float16") = lv607[0]
            top4_softmax_158: R.Tensor((seq_len, 4), dtype="int32") = lv607[1]
            lv310 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_158,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape503 = R.call_tir(cls.reshape5, (lv310,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv68_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape503, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv69_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv68_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum34: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv69_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv609 = R.call_tir(cls.get_indices, (cumsum34, top4_softmax_158), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_034: R.Tensor((seq_len * 4,), dtype="int32") = lv609[0]
            get_indices_134: R.Tensor((seq_len * 4,), dtype="int32") = lv609[1]
            lv610 = R.call_tir(cls.get_expert_instance_indptr, (cumsum34,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take35 = R.call_tir(cls.take, (reshape502, get_indices_134), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv611 = R.call_tir(cls.dequantize_group_gemm, (take35, model_layers_10_mlp_moe_gate_up_proj_q_weight3, model_layers_10_mlp_moe_gate_up_proj_q_scale3, lv610), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv311 = R.call_tir(cls.fused_split_silu_multiply, (lv611,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv612 = R.call_tir(cls.dequantize_group_gemm1, (lv311, model_layers_10_mlp_moe_down_proj_q_weight3, model_layers_10_mlp_moe_down_proj_q_scale3, lv610), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv613 = R.call_tir(cls.scatter_output, (lv612, get_indices_034), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape504 = R.call_tir(cls.reshape6, (top4_softmax_058,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape505 = R.call_tir(cls.reshape7, (lv613,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv312 = R.call_tir(cls.fused_multiply1_sum, (reshape505, reshape504), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv70 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_10_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_10_mlp_shared_expert_gate_up_proj_q_scale3, reshape502), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv313 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv70,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv314 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape502, model_layers_10_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv34_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_10_mlp_shared_expert_down_proj_q_weight3, model_layers_10_mlp_shared_expert_down_proj_q_scale3, lv313, lv314, lv312), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape506 = R.call_tir(cls.reshape14, (lv34_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv138 = R.call_tir(cls.fuse_add_norm_prefill, (reshape506, lv137, model_layers_11_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv139: R.Tensor((1, seq_len, 2048), dtype="float16") = lv138[1]
            rms_norm120: R.Tensor((1, seq_len, 2048), dtype="float16") = lv138[0]
            lv35 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_11_self_attn_c_attn_q_weight3, model_layers_11_self_attn_c_attn_q_scale3, rms_norm120, model_layers_11_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape507 = R.call_tir(cls.reshape9, (lv35,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape508 = R.call_tir(cls.reshape10, (reshape507,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv617 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(11), R.prim_value(T.float32(1.0)), reshape508), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape509 = R.call_tir(cls.reshape11, (lv617,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape510 = R.call_tir(cls.reshape12, (reshape509,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv71 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_11_self_attn_o_proj_q_weight3, model_layers_11_self_attn_o_proj_q_scale3, reshape510), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv140 = R.call_tir(cls.fuse_add_norm_prefill, (lv71, lv139, model_layers_11_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv141: R.Tensor((1, seq_len, 2048), dtype="float16") = lv140[1]
            rms_norm121: R.Tensor((1, seq_len, 2048), dtype="float16") = lv140[0]
            reshape511 = R.call_tir(cls.reshape13, (rms_norm121,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv317 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape511, model_layers_11_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv318 = R.call_tir(cls.fused_softmax_cast1, (lv317,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv619 = R.call_tir(cls.top4_softmax, (lv318,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_059: R.Tensor((seq_len, 4), dtype="float16") = lv619[0]
            top4_softmax_159: R.Tensor((seq_len, 4), dtype="int32") = lv619[1]
            lv319 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_159,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape512 = R.call_tir(cls.reshape5, (lv319,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv70_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape512, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv71_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv70_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum35: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv71_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv621 = R.call_tir(cls.get_indices, (cumsum35, top4_softmax_159), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_035: R.Tensor((seq_len * 4,), dtype="int32") = lv621[0]
            get_indices_135: R.Tensor((seq_len * 4,), dtype="int32") = lv621[1]
            lv622 = R.call_tir(cls.get_expert_instance_indptr, (cumsum35,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take36 = R.call_tir(cls.take, (reshape511, get_indices_135), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv623 = R.call_tir(cls.dequantize_group_gemm, (take36, model_layers_11_mlp_moe_gate_up_proj_q_weight3, model_layers_11_mlp_moe_gate_up_proj_q_scale3, lv622), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv320 = R.call_tir(cls.fused_split_silu_multiply, (lv623,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv624 = R.call_tir(cls.dequantize_group_gemm1, (lv320, model_layers_11_mlp_moe_down_proj_q_weight3, model_layers_11_mlp_moe_down_proj_q_scale3, lv622), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv625 = R.call_tir(cls.scatter_output, (lv624, get_indices_035), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape513 = R.call_tir(cls.reshape6, (top4_softmax_059,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape514 = R.call_tir(cls.reshape7, (lv625,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv321 = R.call_tir(cls.fused_multiply1_sum, (reshape514, reshape513), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv72 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_11_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_11_mlp_shared_expert_gate_up_proj_q_scale3, reshape511), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv322 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv72,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv323 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape511, model_layers_11_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv35_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_11_mlp_shared_expert_down_proj_q_weight3, model_layers_11_mlp_shared_expert_down_proj_q_scale3, lv322, lv323, lv321), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape515 = R.call_tir(cls.reshape14, (lv35_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv142 = R.call_tir(cls.fuse_add_norm_prefill, (reshape515, lv141, model_layers_12_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv143: R.Tensor((1, seq_len, 2048), dtype="float16") = lv142[1]
            rms_norm122: R.Tensor((1, seq_len, 2048), dtype="float16") = lv142[0]
            lv36 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_12_self_attn_c_attn_q_weight3, model_layers_12_self_attn_c_attn_q_scale3, rms_norm122, model_layers_12_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape516 = R.call_tir(cls.reshape9, (lv36,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape517 = R.call_tir(cls.reshape10, (reshape516,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv629 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(12), R.prim_value(T.float32(1.0)), reshape517), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape518 = R.call_tir(cls.reshape11, (lv629,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape519 = R.call_tir(cls.reshape12, (reshape518,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv73 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_12_self_attn_o_proj_q_weight3, model_layers_12_self_attn_o_proj_q_scale3, reshape519), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv144 = R.call_tir(cls.fuse_add_norm_prefill, (lv73, lv143, model_layers_12_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv145: R.Tensor((1, seq_len, 2048), dtype="float16") = lv144[1]
            rms_norm123: R.Tensor((1, seq_len, 2048), dtype="float16") = lv144[0]
            reshape520 = R.call_tir(cls.reshape13, (rms_norm123,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv326 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape520, model_layers_12_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv327 = R.call_tir(cls.fused_softmax_cast1, (lv326,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv631 = R.call_tir(cls.top4_softmax, (lv327,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_060: R.Tensor((seq_len, 4), dtype="float16") = lv631[0]
            top4_softmax_160: R.Tensor((seq_len, 4), dtype="int32") = lv631[1]
            lv328 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_160,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape521 = R.call_tir(cls.reshape5, (lv328,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv72_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape521, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv73_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv72_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum36: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv73_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv633 = R.call_tir(cls.get_indices, (cumsum36, top4_softmax_160), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_036: R.Tensor((seq_len * 4,), dtype="int32") = lv633[0]
            get_indices_136: R.Tensor((seq_len * 4,), dtype="int32") = lv633[1]
            lv634 = R.call_tir(cls.get_expert_instance_indptr, (cumsum36,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take37 = R.call_tir(cls.take, (reshape520, get_indices_136), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv635 = R.call_tir(cls.dequantize_group_gemm, (take37, model_layers_12_mlp_moe_gate_up_proj_q_weight3, model_layers_12_mlp_moe_gate_up_proj_q_scale3, lv634), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv329 = R.call_tir(cls.fused_split_silu_multiply, (lv635,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv636 = R.call_tir(cls.dequantize_group_gemm1, (lv329, model_layers_12_mlp_moe_down_proj_q_weight3, model_layers_12_mlp_moe_down_proj_q_scale3, lv634), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv637 = R.call_tir(cls.scatter_output, (lv636, get_indices_036), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape522 = R.call_tir(cls.reshape6, (top4_softmax_060,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape523 = R.call_tir(cls.reshape7, (lv637,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv330 = R.call_tir(cls.fused_multiply1_sum, (reshape523, reshape522), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv74 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_12_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_12_mlp_shared_expert_gate_up_proj_q_scale3, reshape520), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv331 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv74,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv332 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape520, model_layers_12_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv36_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_12_mlp_shared_expert_down_proj_q_weight3, model_layers_12_mlp_shared_expert_down_proj_q_scale3, lv331, lv332, lv330), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape524 = R.call_tir(cls.reshape14, (lv36_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv146 = R.call_tir(cls.fuse_add_norm_prefill, (reshape524, lv145, model_layers_13_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv147: R.Tensor((1, seq_len, 2048), dtype="float16") = lv146[1]
            rms_norm124: R.Tensor((1, seq_len, 2048), dtype="float16") = lv146[0]
            lv37 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_13_self_attn_c_attn_q_weight3, model_layers_13_self_attn_c_attn_q_scale3, rms_norm124, model_layers_13_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape525 = R.call_tir(cls.reshape9, (lv37,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape526 = R.call_tir(cls.reshape10, (reshape525,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv641 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(13), R.prim_value(T.float32(1.0)), reshape526), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape527 = R.call_tir(cls.reshape11, (lv641,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape528 = R.call_tir(cls.reshape12, (reshape527,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv75 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_13_self_attn_o_proj_q_weight3, model_layers_13_self_attn_o_proj_q_scale3, reshape528), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv148 = R.call_tir(cls.fuse_add_norm_prefill, (lv75, lv147, model_layers_13_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv149: R.Tensor((1, seq_len, 2048), dtype="float16") = lv148[1]
            rms_norm125: R.Tensor((1, seq_len, 2048), dtype="float16") = lv148[0]
            reshape529 = R.call_tir(cls.reshape13, (rms_norm125,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv335 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape529, model_layers_13_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv336 = R.call_tir(cls.fused_softmax_cast1, (lv335,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv643 = R.call_tir(cls.top4_softmax, (lv336,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_061: R.Tensor((seq_len, 4), dtype="float16") = lv643[0]
            top4_softmax_161: R.Tensor((seq_len, 4), dtype="int32") = lv643[1]
            lv337 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_161,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape530 = R.call_tir(cls.reshape5, (lv337,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv74_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape530, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv75_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv74_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum37: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv75_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv645 = R.call_tir(cls.get_indices, (cumsum37, top4_softmax_161), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_037: R.Tensor((seq_len * 4,), dtype="int32") = lv645[0]
            get_indices_137: R.Tensor((seq_len * 4,), dtype="int32") = lv645[1]
            lv646 = R.call_tir(cls.get_expert_instance_indptr, (cumsum37,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take38 = R.call_tir(cls.take, (reshape529, get_indices_137), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv647 = R.call_tir(cls.dequantize_group_gemm, (take38, model_layers_13_mlp_moe_gate_up_proj_q_weight3, model_layers_13_mlp_moe_gate_up_proj_q_scale3, lv646), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv338 = R.call_tir(cls.fused_split_silu_multiply, (lv647,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv648 = R.call_tir(cls.dequantize_group_gemm1, (lv338, model_layers_13_mlp_moe_down_proj_q_weight3, model_layers_13_mlp_moe_down_proj_q_scale3, lv646), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv649 = R.call_tir(cls.scatter_output, (lv648, get_indices_037), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape531 = R.call_tir(cls.reshape6, (top4_softmax_061,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape532 = R.call_tir(cls.reshape7, (lv649,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv339 = R.call_tir(cls.fused_multiply1_sum, (reshape532, reshape531), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv76 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_13_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_13_mlp_shared_expert_gate_up_proj_q_scale3, reshape529), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv340 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv76,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv341 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape529, model_layers_13_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv37_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_13_mlp_shared_expert_down_proj_q_weight3, model_layers_13_mlp_shared_expert_down_proj_q_scale3, lv340, lv341, lv339), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape533 = R.call_tir(cls.reshape14, (lv37_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv150 = R.call_tir(cls.fuse_add_norm_prefill, (reshape533, lv149, model_layers_14_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv151: R.Tensor((1, seq_len, 2048), dtype="float16") = lv150[1]
            rms_norm126: R.Tensor((1, seq_len, 2048), dtype="float16") = lv150[0]
            lv38 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_14_self_attn_c_attn_q_weight3, model_layers_14_self_attn_c_attn_q_scale3, rms_norm126, model_layers_14_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape534 = R.call_tir(cls.reshape9, (lv38,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape535 = R.call_tir(cls.reshape10, (reshape534,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv653 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(14), R.prim_value(T.float32(1.0)), reshape535), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape536 = R.call_tir(cls.reshape11, (lv653,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape537 = R.call_tir(cls.reshape12, (reshape536,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv77 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_14_self_attn_o_proj_q_weight3, model_layers_14_self_attn_o_proj_q_scale3, reshape537), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv152 = R.call_tir(cls.fuse_add_norm_prefill, (lv77, lv151, model_layers_14_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv153: R.Tensor((1, seq_len, 2048), dtype="float16") = lv152[1]
            rms_norm127: R.Tensor((1, seq_len, 2048), dtype="float16") = lv152[0]
            reshape538 = R.call_tir(cls.reshape13, (rms_norm127,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv344 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape538, model_layers_14_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv345 = R.call_tir(cls.fused_softmax_cast1, (lv344,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv655 = R.call_tir(cls.top4_softmax, (lv345,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_062: R.Tensor((seq_len, 4), dtype="float16") = lv655[0]
            top4_softmax_162: R.Tensor((seq_len, 4), dtype="int32") = lv655[1]
            lv346 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_162,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape539 = R.call_tir(cls.reshape5, (lv346,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv76_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape539, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv77_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv76_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum38: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv77_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv657 = R.call_tir(cls.get_indices, (cumsum38, top4_softmax_162), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_038: R.Tensor((seq_len * 4,), dtype="int32") = lv657[0]
            get_indices_138: R.Tensor((seq_len * 4,), dtype="int32") = lv657[1]
            lv658 = R.call_tir(cls.get_expert_instance_indptr, (cumsum38,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take39 = R.call_tir(cls.take, (reshape538, get_indices_138), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv659 = R.call_tir(cls.dequantize_group_gemm, (take39, model_layers_14_mlp_moe_gate_up_proj_q_weight3, model_layers_14_mlp_moe_gate_up_proj_q_scale3, lv658), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv347 = R.call_tir(cls.fused_split_silu_multiply, (lv659,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv660 = R.call_tir(cls.dequantize_group_gemm1, (lv347, model_layers_14_mlp_moe_down_proj_q_weight3, model_layers_14_mlp_moe_down_proj_q_scale3, lv658), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv661 = R.call_tir(cls.scatter_output, (lv660, get_indices_038), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape540 = R.call_tir(cls.reshape6, (top4_softmax_062,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape541 = R.call_tir(cls.reshape7, (lv661,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv348 = R.call_tir(cls.fused_multiply1_sum, (reshape541, reshape540), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv78 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_14_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_14_mlp_shared_expert_gate_up_proj_q_scale3, reshape538), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv349 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv78,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv350 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape538, model_layers_14_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv38_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_14_mlp_shared_expert_down_proj_q_weight3, model_layers_14_mlp_shared_expert_down_proj_q_scale3, lv349, lv350, lv348), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape542 = R.call_tir(cls.reshape14, (lv38_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv154 = R.call_tir(cls.fuse_add_norm_prefill, (reshape542, lv153, model_layers_15_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv155: R.Tensor((1, seq_len, 2048), dtype="float16") = lv154[1]
            rms_norm128: R.Tensor((1, seq_len, 2048), dtype="float16") = lv154[0]
            lv39 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_15_self_attn_c_attn_q_weight3, model_layers_15_self_attn_c_attn_q_scale3, rms_norm128, model_layers_15_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape543 = R.call_tir(cls.reshape9, (lv39,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape544 = R.call_tir(cls.reshape10, (reshape543,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv665 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(15), R.prim_value(T.float32(1.0)), reshape544), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape545 = R.call_tir(cls.reshape11, (lv665,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape546 = R.call_tir(cls.reshape12, (reshape545,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv79 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_15_self_attn_o_proj_q_weight3, model_layers_15_self_attn_o_proj_q_scale3, reshape546), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv156 = R.call_tir(cls.fuse_add_norm_prefill, (lv79, lv155, model_layers_15_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv157: R.Tensor((1, seq_len, 2048), dtype="float16") = lv156[1]
            rms_norm129: R.Tensor((1, seq_len, 2048), dtype="float16") = lv156[0]
            reshape547 = R.call_tir(cls.reshape13, (rms_norm129,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv353 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape547, model_layers_15_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv354 = R.call_tir(cls.fused_softmax_cast1, (lv353,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv667 = R.call_tir(cls.top4_softmax, (lv354,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_063: R.Tensor((seq_len, 4), dtype="float16") = lv667[0]
            top4_softmax_163: R.Tensor((seq_len, 4), dtype="int32") = lv667[1]
            lv355 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_163,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape548 = R.call_tir(cls.reshape5, (lv355,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv78_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape548, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv79_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv78_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum39: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv79_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv669 = R.call_tir(cls.get_indices, (cumsum39, top4_softmax_163), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_039: R.Tensor((seq_len * 4,), dtype="int32") = lv669[0]
            get_indices_139: R.Tensor((seq_len * 4,), dtype="int32") = lv669[1]
            lv670 = R.call_tir(cls.get_expert_instance_indptr, (cumsum39,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take40 = R.call_tir(cls.take, (reshape547, get_indices_139), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv671 = R.call_tir(cls.dequantize_group_gemm, (take40, model_layers_15_mlp_moe_gate_up_proj_q_weight3, model_layers_15_mlp_moe_gate_up_proj_q_scale3, lv670), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv356 = R.call_tir(cls.fused_split_silu_multiply, (lv671,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv672 = R.call_tir(cls.dequantize_group_gemm1, (lv356, model_layers_15_mlp_moe_down_proj_q_weight3, model_layers_15_mlp_moe_down_proj_q_scale3, lv670), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv673 = R.call_tir(cls.scatter_output, (lv672, get_indices_039), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape549 = R.call_tir(cls.reshape6, (top4_softmax_063,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape550 = R.call_tir(cls.reshape7, (lv673,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv357 = R.call_tir(cls.fused_multiply1_sum, (reshape550, reshape549), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv80 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_15_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_15_mlp_shared_expert_gate_up_proj_q_scale3, reshape547), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv358 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv80,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv359 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape547, model_layers_15_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv39_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_15_mlp_shared_expert_down_proj_q_weight3, model_layers_15_mlp_shared_expert_down_proj_q_scale3, lv358, lv359, lv357), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape551 = R.call_tir(cls.reshape14, (lv39_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv158 = R.call_tir(cls.fuse_add_norm_prefill, (reshape551, lv157, model_layers_16_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv159: R.Tensor((1, seq_len, 2048), dtype="float16") = lv158[1]
            rms_norm130: R.Tensor((1, seq_len, 2048), dtype="float16") = lv158[0]
            lv40 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_16_self_attn_c_attn_q_weight3, model_layers_16_self_attn_c_attn_q_scale3, rms_norm130, model_layers_16_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape552 = R.call_tir(cls.reshape9, (lv40,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape553 = R.call_tir(cls.reshape10, (reshape552,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv677 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(16), R.prim_value(T.float32(1.0)), reshape553), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape554 = R.call_tir(cls.reshape11, (lv677,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape555 = R.call_tir(cls.reshape12, (reshape554,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv81 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_16_self_attn_o_proj_q_weight3, model_layers_16_self_attn_o_proj_q_scale3, reshape555), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv160 = R.call_tir(cls.fuse_add_norm_prefill, (lv81, lv159, model_layers_16_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv161: R.Tensor((1, seq_len, 2048), dtype="float16") = lv160[1]
            rms_norm131: R.Tensor((1, seq_len, 2048), dtype="float16") = lv160[0]
            reshape556 = R.call_tir(cls.reshape13, (rms_norm131,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv362 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape556, model_layers_16_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv363 = R.call_tir(cls.fused_softmax_cast1, (lv362,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv679 = R.call_tir(cls.top4_softmax, (lv363,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_064: R.Tensor((seq_len, 4), dtype="float16") = lv679[0]
            top4_softmax_164: R.Tensor((seq_len, 4), dtype="int32") = lv679[1]
            lv364 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_164,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape557 = R.call_tir(cls.reshape5, (lv364,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv80_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape557, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv81_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv80_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum40: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv81_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv681 = R.call_tir(cls.get_indices, (cumsum40, top4_softmax_164), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_040: R.Tensor((seq_len * 4,), dtype="int32") = lv681[0]
            get_indices_140: R.Tensor((seq_len * 4,), dtype="int32") = lv681[1]
            lv682 = R.call_tir(cls.get_expert_instance_indptr, (cumsum40,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take41 = R.call_tir(cls.take, (reshape556, get_indices_140), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv683 = R.call_tir(cls.dequantize_group_gemm, (take41, model_layers_16_mlp_moe_gate_up_proj_q_weight3, model_layers_16_mlp_moe_gate_up_proj_q_scale3, lv682), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv365 = R.call_tir(cls.fused_split_silu_multiply, (lv683,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv684 = R.call_tir(cls.dequantize_group_gemm1, (lv365, model_layers_16_mlp_moe_down_proj_q_weight3, model_layers_16_mlp_moe_down_proj_q_scale3, lv682), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv685 = R.call_tir(cls.scatter_output, (lv684, get_indices_040), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape558 = R.call_tir(cls.reshape6, (top4_softmax_064,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape559 = R.call_tir(cls.reshape7, (lv685,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv366 = R.call_tir(cls.fused_multiply1_sum, (reshape559, reshape558), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv82 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_16_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_16_mlp_shared_expert_gate_up_proj_q_scale3, reshape556), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv367 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv82,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv368 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape556, model_layers_16_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv40_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_16_mlp_shared_expert_down_proj_q_weight3, model_layers_16_mlp_shared_expert_down_proj_q_scale3, lv367, lv368, lv366), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape560 = R.call_tir(cls.reshape14, (lv40_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv162 = R.call_tir(cls.fuse_add_norm_prefill, (reshape560, lv161, model_layers_17_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv163: R.Tensor((1, seq_len, 2048), dtype="float16") = lv162[1]
            rms_norm132: R.Tensor((1, seq_len, 2048), dtype="float16") = lv162[0]
            lv41 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_17_self_attn_c_attn_q_weight3, model_layers_17_self_attn_c_attn_q_scale3, rms_norm132, model_layers_17_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape561 = R.call_tir(cls.reshape9, (lv41,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape562 = R.call_tir(cls.reshape10, (reshape561,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv689 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(17), R.prim_value(T.float32(1.0)), reshape562), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape563 = R.call_tir(cls.reshape11, (lv689,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape564 = R.call_tir(cls.reshape12, (reshape563,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv83 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_17_self_attn_o_proj_q_weight3, model_layers_17_self_attn_o_proj_q_scale3, reshape564), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv164 = R.call_tir(cls.fuse_add_norm_prefill, (lv83, lv163, model_layers_17_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv165: R.Tensor((1, seq_len, 2048), dtype="float16") = lv164[1]
            rms_norm133: R.Tensor((1, seq_len, 2048), dtype="float16") = lv164[0]
            reshape565 = R.call_tir(cls.reshape13, (rms_norm133,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv371 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape565, model_layers_17_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv372 = R.call_tir(cls.fused_softmax_cast1, (lv371,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv691 = R.call_tir(cls.top4_softmax, (lv372,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_065: R.Tensor((seq_len, 4), dtype="float16") = lv691[0]
            top4_softmax_165: R.Tensor((seq_len, 4), dtype="int32") = lv691[1]
            lv373 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_165,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape566 = R.call_tir(cls.reshape5, (lv373,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv82_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape566, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv83_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv82_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum41: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv83_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv693 = R.call_tir(cls.get_indices, (cumsum41, top4_softmax_165), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_041: R.Tensor((seq_len * 4,), dtype="int32") = lv693[0]
            get_indices_141: R.Tensor((seq_len * 4,), dtype="int32") = lv693[1]
            lv694 = R.call_tir(cls.get_expert_instance_indptr, (cumsum41,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take42 = R.call_tir(cls.take, (reshape565, get_indices_141), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv695 = R.call_tir(cls.dequantize_group_gemm, (take42, model_layers_17_mlp_moe_gate_up_proj_q_weight3, model_layers_17_mlp_moe_gate_up_proj_q_scale3, lv694), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv374 = R.call_tir(cls.fused_split_silu_multiply, (lv695,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv696 = R.call_tir(cls.dequantize_group_gemm1, (lv374, model_layers_17_mlp_moe_down_proj_q_weight3, model_layers_17_mlp_moe_down_proj_q_scale3, lv694), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv697 = R.call_tir(cls.scatter_output, (lv696, get_indices_041), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape567 = R.call_tir(cls.reshape6, (top4_softmax_065,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape568 = R.call_tir(cls.reshape7, (lv697,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv375 = R.call_tir(cls.fused_multiply1_sum, (reshape568, reshape567), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv84 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_17_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_17_mlp_shared_expert_gate_up_proj_q_scale3, reshape565), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv376 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv84,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv377 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape565, model_layers_17_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv41_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_17_mlp_shared_expert_down_proj_q_weight3, model_layers_17_mlp_shared_expert_down_proj_q_scale3, lv376, lv377, lv375), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape569 = R.call_tir(cls.reshape14, (lv41_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv166 = R.call_tir(cls.fuse_add_norm_prefill, (reshape569, lv165, model_layers_18_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv167: R.Tensor((1, seq_len, 2048), dtype="float16") = lv166[1]
            rms_norm134: R.Tensor((1, seq_len, 2048), dtype="float16") = lv166[0]
            lv42 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_18_self_attn_c_attn_q_weight3, model_layers_18_self_attn_c_attn_q_scale3, rms_norm134, model_layers_18_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape570 = R.call_tir(cls.reshape9, (lv42,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape571 = R.call_tir(cls.reshape10, (reshape570,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv701 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(18), R.prim_value(T.float32(1.0)), reshape571), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape572 = R.call_tir(cls.reshape11, (lv701,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape573 = R.call_tir(cls.reshape12, (reshape572,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv85 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_18_self_attn_o_proj_q_weight3, model_layers_18_self_attn_o_proj_q_scale3, reshape573), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv168 = R.call_tir(cls.fuse_add_norm_prefill, (lv85, lv167, model_layers_18_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv169: R.Tensor((1, seq_len, 2048), dtype="float16") = lv168[1]
            rms_norm135: R.Tensor((1, seq_len, 2048), dtype="float16") = lv168[0]
            reshape574 = R.call_tir(cls.reshape13, (rms_norm135,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv380 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape574, model_layers_18_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv381 = R.call_tir(cls.fused_softmax_cast1, (lv380,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv703 = R.call_tir(cls.top4_softmax, (lv381,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_066: R.Tensor((seq_len, 4), dtype="float16") = lv703[0]
            top4_softmax_166: R.Tensor((seq_len, 4), dtype="int32") = lv703[1]
            lv382 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_166,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape575 = R.call_tir(cls.reshape5, (lv382,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv84_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape575, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv85_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv84_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum42: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv85_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv705 = R.call_tir(cls.get_indices, (cumsum42, top4_softmax_166), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_042: R.Tensor((seq_len * 4,), dtype="int32") = lv705[0]
            get_indices_142: R.Tensor((seq_len * 4,), dtype="int32") = lv705[1]
            lv706 = R.call_tir(cls.get_expert_instance_indptr, (cumsum42,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take43 = R.call_tir(cls.take, (reshape574, get_indices_142), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv707 = R.call_tir(cls.dequantize_group_gemm, (take43, model_layers_18_mlp_moe_gate_up_proj_q_weight3, model_layers_18_mlp_moe_gate_up_proj_q_scale3, lv706), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv383 = R.call_tir(cls.fused_split_silu_multiply, (lv707,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv708 = R.call_tir(cls.dequantize_group_gemm1, (lv383, model_layers_18_mlp_moe_down_proj_q_weight3, model_layers_18_mlp_moe_down_proj_q_scale3, lv706), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv709 = R.call_tir(cls.scatter_output, (lv708, get_indices_042), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape576 = R.call_tir(cls.reshape6, (top4_softmax_066,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape577 = R.call_tir(cls.reshape7, (lv709,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv384 = R.call_tir(cls.fused_multiply1_sum, (reshape577, reshape576), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv86 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_18_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_18_mlp_shared_expert_gate_up_proj_q_scale3, reshape574), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv385 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv86,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv386 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape574, model_layers_18_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv42_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_18_mlp_shared_expert_down_proj_q_weight3, model_layers_18_mlp_shared_expert_down_proj_q_scale3, lv385, lv386, lv384), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape578 = R.call_tir(cls.reshape14, (lv42_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv170 = R.call_tir(cls.fuse_add_norm_prefill, (reshape578, lv169, model_layers_19_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv171: R.Tensor((1, seq_len, 2048), dtype="float16") = lv170[1]
            rms_norm136: R.Tensor((1, seq_len, 2048), dtype="float16") = lv170[0]
            lv43 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_19_self_attn_c_attn_q_weight3, model_layers_19_self_attn_c_attn_q_scale3, rms_norm136, model_layers_19_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape579 = R.call_tir(cls.reshape9, (lv43,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape580 = R.call_tir(cls.reshape10, (reshape579,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv713 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(19), R.prim_value(T.float32(1.0)), reshape580), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape581 = R.call_tir(cls.reshape11, (lv713,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape582 = R.call_tir(cls.reshape12, (reshape581,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv87 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_19_self_attn_o_proj_q_weight3, model_layers_19_self_attn_o_proj_q_scale3, reshape582), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv172 = R.call_tir(cls.fuse_add_norm_prefill, (lv87, lv171, model_layers_19_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv173: R.Tensor((1, seq_len, 2048), dtype="float16") = lv172[1]
            rms_norm137: R.Tensor((1, seq_len, 2048), dtype="float16") = lv172[0]
            reshape583 = R.call_tir(cls.reshape13, (rms_norm137,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv389 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape583, model_layers_19_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv390 = R.call_tir(cls.fused_softmax_cast1, (lv389,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv715 = R.call_tir(cls.top4_softmax, (lv390,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_067: R.Tensor((seq_len, 4), dtype="float16") = lv715[0]
            top4_softmax_167: R.Tensor((seq_len, 4), dtype="int32") = lv715[1]
            lv391 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_167,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape584 = R.call_tir(cls.reshape5, (lv391,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv86_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape584, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv87_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv86_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum43: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv87_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv717 = R.call_tir(cls.get_indices, (cumsum43, top4_softmax_167), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_043: R.Tensor((seq_len * 4,), dtype="int32") = lv717[0]
            get_indices_143: R.Tensor((seq_len * 4,), dtype="int32") = lv717[1]
            lv718 = R.call_tir(cls.get_expert_instance_indptr, (cumsum43,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take44 = R.call_tir(cls.take, (reshape583, get_indices_143), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv719 = R.call_tir(cls.dequantize_group_gemm, (take44, model_layers_19_mlp_moe_gate_up_proj_q_weight3, model_layers_19_mlp_moe_gate_up_proj_q_scale3, lv718), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv392 = R.call_tir(cls.fused_split_silu_multiply, (lv719,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv720 = R.call_tir(cls.dequantize_group_gemm1, (lv392, model_layers_19_mlp_moe_down_proj_q_weight3, model_layers_19_mlp_moe_down_proj_q_scale3, lv718), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv721 = R.call_tir(cls.scatter_output, (lv720, get_indices_043), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape585 = R.call_tir(cls.reshape6, (top4_softmax_067,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape586 = R.call_tir(cls.reshape7, (lv721,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv393 = R.call_tir(cls.fused_multiply1_sum, (reshape586, reshape585), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv88 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_19_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_19_mlp_shared_expert_gate_up_proj_q_scale3, reshape583), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv394 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv88,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv395 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape583, model_layers_19_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv43_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_19_mlp_shared_expert_down_proj_q_weight3, model_layers_19_mlp_shared_expert_down_proj_q_scale3, lv394, lv395, lv393), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape587 = R.call_tir(cls.reshape14, (lv43_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv174 = R.call_tir(cls.fuse_add_norm_prefill, (reshape587, lv173, model_layers_20_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv175: R.Tensor((1, seq_len, 2048), dtype="float16") = lv174[1]
            rms_norm138: R.Tensor((1, seq_len, 2048), dtype="float16") = lv174[0]
            lv44 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_20_self_attn_c_attn_q_weight3, model_layers_20_self_attn_c_attn_q_scale3, rms_norm138, model_layers_20_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape588 = R.call_tir(cls.reshape9, (lv44,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape589 = R.call_tir(cls.reshape10, (reshape588,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv725 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(20), R.prim_value(T.float32(1.0)), reshape589), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape590 = R.call_tir(cls.reshape11, (lv725,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape591 = R.call_tir(cls.reshape12, (reshape590,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv89 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_20_self_attn_o_proj_q_weight3, model_layers_20_self_attn_o_proj_q_scale3, reshape591), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv176 = R.call_tir(cls.fuse_add_norm_prefill, (lv89, lv175, model_layers_20_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv177: R.Tensor((1, seq_len, 2048), dtype="float16") = lv176[1]
            rms_norm139: R.Tensor((1, seq_len, 2048), dtype="float16") = lv176[0]
            reshape592 = R.call_tir(cls.reshape13, (rms_norm139,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv398 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape592, model_layers_20_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv399 = R.call_tir(cls.fused_softmax_cast1, (lv398,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv727 = R.call_tir(cls.top4_softmax, (lv399,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_068: R.Tensor((seq_len, 4), dtype="float16") = lv727[0]
            top4_softmax_168: R.Tensor((seq_len, 4), dtype="int32") = lv727[1]
            lv400 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_168,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape593 = R.call_tir(cls.reshape5, (lv400,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv88_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape593, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv89_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv88_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum44: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv89_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv729 = R.call_tir(cls.get_indices, (cumsum44, top4_softmax_168), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_044: R.Tensor((seq_len * 4,), dtype="int32") = lv729[0]
            get_indices_144: R.Tensor((seq_len * 4,), dtype="int32") = lv729[1]
            lv730 = R.call_tir(cls.get_expert_instance_indptr, (cumsum44,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take45 = R.call_tir(cls.take, (reshape592, get_indices_144), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv731 = R.call_tir(cls.dequantize_group_gemm, (take45, model_layers_20_mlp_moe_gate_up_proj_q_weight3, model_layers_20_mlp_moe_gate_up_proj_q_scale3, lv730), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv401 = R.call_tir(cls.fused_split_silu_multiply, (lv731,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv732 = R.call_tir(cls.dequantize_group_gemm1, (lv401, model_layers_20_mlp_moe_down_proj_q_weight3, model_layers_20_mlp_moe_down_proj_q_scale3, lv730), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv733 = R.call_tir(cls.scatter_output, (lv732, get_indices_044), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape594 = R.call_tir(cls.reshape6, (top4_softmax_068,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape595 = R.call_tir(cls.reshape7, (lv733,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv402 = R.call_tir(cls.fused_multiply1_sum, (reshape595, reshape594), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv90 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_20_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_20_mlp_shared_expert_gate_up_proj_q_scale3, reshape592), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv403 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv90,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv404 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape592, model_layers_20_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv44_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_20_mlp_shared_expert_down_proj_q_weight3, model_layers_20_mlp_shared_expert_down_proj_q_scale3, lv403, lv404, lv402), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape596 = R.call_tir(cls.reshape14, (lv44_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv178 = R.call_tir(cls.fuse_add_norm_prefill, (reshape596, lv177, model_layers_21_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv179: R.Tensor((1, seq_len, 2048), dtype="float16") = lv178[1]
            rms_norm140: R.Tensor((1, seq_len, 2048), dtype="float16") = lv178[0]
            lv45 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_21_self_attn_c_attn_q_weight3, model_layers_21_self_attn_c_attn_q_scale3, rms_norm140, model_layers_21_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape597 = R.call_tir(cls.reshape9, (lv45,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape598 = R.call_tir(cls.reshape10, (reshape597,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv737 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(21), R.prim_value(T.float32(1.0)), reshape598), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape599 = R.call_tir(cls.reshape11, (lv737,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape600 = R.call_tir(cls.reshape12, (reshape599,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv91 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_21_self_attn_o_proj_q_weight3, model_layers_21_self_attn_o_proj_q_scale3, reshape600), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv180 = R.call_tir(cls.fuse_add_norm_prefill, (lv91, lv179, model_layers_21_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv181: R.Tensor((1, seq_len, 2048), dtype="float16") = lv180[1]
            rms_norm141: R.Tensor((1, seq_len, 2048), dtype="float16") = lv180[0]
            reshape601 = R.call_tir(cls.reshape13, (rms_norm141,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv407 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape601, model_layers_21_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv408 = R.call_tir(cls.fused_softmax_cast1, (lv407,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv739 = R.call_tir(cls.top4_softmax, (lv408,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_069: R.Tensor((seq_len, 4), dtype="float16") = lv739[0]
            top4_softmax_169: R.Tensor((seq_len, 4), dtype="int32") = lv739[1]
            lv409 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_169,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape602 = R.call_tir(cls.reshape5, (lv409,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv90_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape602, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv91_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv90_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum45: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv91_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv741 = R.call_tir(cls.get_indices, (cumsum45, top4_softmax_169), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_045: R.Tensor((seq_len * 4,), dtype="int32") = lv741[0]
            get_indices_145: R.Tensor((seq_len * 4,), dtype="int32") = lv741[1]
            lv742 = R.call_tir(cls.get_expert_instance_indptr, (cumsum45,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take46 = R.call_tir(cls.take, (reshape601, get_indices_145), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv743 = R.call_tir(cls.dequantize_group_gemm, (take46, model_layers_21_mlp_moe_gate_up_proj_q_weight3, model_layers_21_mlp_moe_gate_up_proj_q_scale3, lv742), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv410 = R.call_tir(cls.fused_split_silu_multiply, (lv743,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv744 = R.call_tir(cls.dequantize_group_gemm1, (lv410, model_layers_21_mlp_moe_down_proj_q_weight3, model_layers_21_mlp_moe_down_proj_q_scale3, lv742), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv745 = R.call_tir(cls.scatter_output, (lv744, get_indices_045), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape603 = R.call_tir(cls.reshape6, (top4_softmax_069,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape604 = R.call_tir(cls.reshape7, (lv745,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv411 = R.call_tir(cls.fused_multiply1_sum, (reshape604, reshape603), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv92 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_21_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_21_mlp_shared_expert_gate_up_proj_q_scale3, reshape601), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv412 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv92,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv413 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape601, model_layers_21_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv45_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_21_mlp_shared_expert_down_proj_q_weight3, model_layers_21_mlp_shared_expert_down_proj_q_scale3, lv412, lv413, lv411), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape605 = R.call_tir(cls.reshape14, (lv45_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv182 = R.call_tir(cls.fuse_add_norm_prefill, (reshape605, lv181, model_layers_22_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv183: R.Tensor((1, seq_len, 2048), dtype="float16") = lv182[1]
            rms_norm142: R.Tensor((1, seq_len, 2048), dtype="float16") = lv182[0]
            lv46 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_22_self_attn_c_attn_q_weight3, model_layers_22_self_attn_c_attn_q_scale3, rms_norm142, model_layers_22_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape606 = R.call_tir(cls.reshape9, (lv46,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape607 = R.call_tir(cls.reshape10, (reshape606,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv749 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(22), R.prim_value(T.float32(1.0)), reshape607), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape608 = R.call_tir(cls.reshape11, (lv749,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape609 = R.call_tir(cls.reshape12, (reshape608,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv93 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_22_self_attn_o_proj_q_weight3, model_layers_22_self_attn_o_proj_q_scale3, reshape609), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv184 = R.call_tir(cls.fuse_add_norm_prefill, (lv93, lv183, model_layers_22_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv185: R.Tensor((1, seq_len, 2048), dtype="float16") = lv184[1]
            rms_norm143: R.Tensor((1, seq_len, 2048), dtype="float16") = lv184[0]
            reshape610 = R.call_tir(cls.reshape13, (rms_norm143,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv416 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape610, model_layers_22_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv417 = R.call_tir(cls.fused_softmax_cast1, (lv416,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv751 = R.call_tir(cls.top4_softmax, (lv417,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_070: R.Tensor((seq_len, 4), dtype="float16") = lv751[0]
            top4_softmax_170: R.Tensor((seq_len, 4), dtype="int32") = lv751[1]
            lv418 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_170,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape611 = R.call_tir(cls.reshape5, (lv418,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv92_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape611, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv93_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv92_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum46: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv93_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv753 = R.call_tir(cls.get_indices, (cumsum46, top4_softmax_170), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_046: R.Tensor((seq_len * 4,), dtype="int32") = lv753[0]
            get_indices_146: R.Tensor((seq_len * 4,), dtype="int32") = lv753[1]
            lv754 = R.call_tir(cls.get_expert_instance_indptr, (cumsum46,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take47 = R.call_tir(cls.take, (reshape610, get_indices_146), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv755 = R.call_tir(cls.dequantize_group_gemm, (take47, model_layers_22_mlp_moe_gate_up_proj_q_weight3, model_layers_22_mlp_moe_gate_up_proj_q_scale3, lv754), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv419 = R.call_tir(cls.fused_split_silu_multiply, (lv755,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv756 = R.call_tir(cls.dequantize_group_gemm1, (lv419, model_layers_22_mlp_moe_down_proj_q_weight3, model_layers_22_mlp_moe_down_proj_q_scale3, lv754), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv757 = R.call_tir(cls.scatter_output, (lv756, get_indices_046), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape612 = R.call_tir(cls.reshape6, (top4_softmax_070,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape613 = R.call_tir(cls.reshape7, (lv757,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv420 = R.call_tir(cls.fused_multiply1_sum, (reshape613, reshape612), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv94 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_22_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_22_mlp_shared_expert_gate_up_proj_q_scale3, reshape610), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv421 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv94,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv422 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape610, model_layers_22_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv46_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_22_mlp_shared_expert_down_proj_q_weight3, model_layers_22_mlp_shared_expert_down_proj_q_scale3, lv421, lv422, lv420), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape614 = R.call_tir(cls.reshape14, (lv46_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv186 = R.call_tir(cls.fuse_add_norm_prefill, (reshape614, lv185, model_layers_23_input_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv187: R.Tensor((1, seq_len, 2048), dtype="float16") = lv186[1]
            rms_norm144: R.Tensor((1, seq_len, 2048), dtype="float16") = lv186[0]
            lv47 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_23_self_attn_c_attn_q_weight3, model_layers_23_self_attn_c_attn_q_scale3, rms_norm144, model_layers_23_self_attn_c_attn_bias3), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape615 = R.call_tir(cls.reshape9, (lv47,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape616 = R.call_tir(cls.reshape10, (reshape615,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv761 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(23), R.prim_value(T.float32(1.0)), reshape616), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape617 = R.call_tir(cls.reshape11, (lv761,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape618 = R.call_tir(cls.reshape12, (reshape617,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv95 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_23_self_attn_o_proj_q_weight3, model_layers_23_self_attn_o_proj_q_scale3, reshape618), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv188 = R.call_tir(cls.fuse_add_norm_prefill, (lv95, lv187, model_layers_23_post_attention_layernorm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv189: R.Tensor((1, seq_len, 2048), dtype="float16") = lv188[1]
            rms_norm145: R.Tensor((1, seq_len, 2048), dtype="float16") = lv188[0]
            reshape619 = R.call_tir(cls.reshape13, (rms_norm145,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv425 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape619, model_layers_23_mlp_gate_weight3), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv426 = R.call_tir(cls.fused_softmax_cast1, (lv425,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv763 = R.call_tir(cls.top4_softmax, (lv426,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_071: R.Tensor((seq_len, 4), dtype="float16") = lv763[0]
            top4_softmax_171: R.Tensor((seq_len, 4), dtype="int32") = lv763[1]
            lv427 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_171,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape620 = R.call_tir(cls.reshape5, (lv427,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv94_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape620, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv95_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv94_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum47: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv95_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv765 = R.call_tir(cls.get_indices, (cumsum47, top4_softmax_171), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_047: R.Tensor((seq_len * 4,), dtype="int32") = lv765[0]
            get_indices_147: R.Tensor((seq_len * 4,), dtype="int32") = lv765[1]
            lv766 = R.call_tir(cls.get_expert_instance_indptr, (cumsum47,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take48 = R.call_tir(cls.take, (reshape619, get_indices_147), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv767 = R.call_tir(cls.dequantize_group_gemm, (take48, model_layers_23_mlp_moe_gate_up_proj_q_weight3, model_layers_23_mlp_moe_gate_up_proj_q_scale3, lv766), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv428 = R.call_tir(cls.fused_split_silu_multiply, (lv767,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv768 = R.call_tir(cls.dequantize_group_gemm1, (lv428, model_layers_23_mlp_moe_down_proj_q_weight3, model_layers_23_mlp_moe_down_proj_q_scale3, lv766), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv769 = R.call_tir(cls.scatter_output, (lv768, get_indices_047), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape621 = R.call_tir(cls.reshape6, (top4_softmax_071,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape622 = R.call_tir(cls.reshape7, (lv769,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv429 = R.call_tir(cls.fused_multiply1_sum, (reshape622, reshape621), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv96_1 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_23_mlp_shared_expert_gate_up_proj_q_weight3, model_layers_23_mlp_shared_expert_gate_up_proj_q_scale3, reshape619), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv430 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv96_1,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv431 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape619, model_layers_23_mlp_shared_expert_gate_weight3), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv47_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_23_mlp_shared_expert_down_proj_q_weight3, model_layers_23_mlp_shared_expert_down_proj_q_scale3, lv430, lv431, lv429), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape623 = R.call_tir(cls.reshape14, (lv47_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv190 = R.call_tir(cls.fuse_add_norm_prefill, (reshape623, lv189, model_norm_weight3), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            rms_norm146: R.Tensor((1, seq_len, 2048), dtype="float16") = lv190[0]
            take49 = R.call_tir(cls.take1, (rms_norm146, logit_positions), out_sinfo=R.Tensor((1, batch_size, 2048), dtype="float16"))
            lv97_1 = R.call_tir(cls.fused_dequantize_fused_NT_matmul9_cast3, (lm_head_q_weight3, lm_head_q_scale3, take49), out_sinfo=R.Tensor((1, batch_size, 151936), dtype="float32"))
            gv3: R.Tuple(R.Tensor((1, batch_size, 151936), dtype="float32"), R.Object) = lv97_1, paged_kv_cache
            R.output(gv3)
        return gv3

    @R.function
    def batch_verify(input_embeds: R.Tensor((1, "seq_len", 2048), dtype="float16"), paged_kv_cache: R.Object, packed_params: R.Tuple(R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), 
R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), 
dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), 
R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), 
dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), 
R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), 
dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), 
R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((6144, 256), dtype="uint32"), R.Tensor((6144, 64), dtype="float16"), R.Tensor((6144,), dtype="float16"), R.Tensor((2048, 256), dtype="uint32"), R.Tensor((2048, 64), dtype="float16"), R.Tensor((11264, 256), dtype="uint32"), R.Tensor((11264, 64), dtype="float16"), R.Tensor((2048, 704), 
dtype="uint32"), R.Tensor((2048, 176), dtype="float16"), R.Tensor((1, 2048), dtype="float16"), R.Tensor((60, 2048), dtype="float16"), R.Tensor((60, 2816, 256), dtype="uint32"), R.Tensor((60, 2816, 64), dtype="float16"), R.Tensor((60, 2048, 176), dtype="uint32"), R.Tensor((60, 2048, 44), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((2048,), dtype="float16"), R.Tensor((151936, 256), dtype="uint32"), R.Tensor((151936, 64), dtype="float16"))) -> R.Tuple(R.Tensor((1, "seq_len", 151936), dtype="float32"), R.Object):
        seq_len = T.int64()
        R.func_attr({"num_input": 2, "pipeline_parallel_stages": 1, "relax.memory_plan_dynamic_func_output": True, "relax.rewrite_cuda_graph.capture_symbolic_vars": ["batch_size", "seq_len"], "tir_non_negative_var": ["vocab_size"], "tir_var_upper_bound": {"batch_size": 80, "seq_len": 32768, "total_seq_len": 32768}})
        cls = Module
        with R.dataflow():
            model_layers_0_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[2]
            model_layers_0_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[3]
            model_layers_0_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[4]
            model_layers_0_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[5]
            model_layers_0_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[6]
            model_layers_0_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[7]
            model_layers_0_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[8]
            model_layers_0_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[9]
            model_layers_0_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[10]
            model_layers_0_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[11]
            model_layers_0_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[12]
            model_layers_0_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[13]
            model_layers_0_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[14]
            model_layers_0_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[15]
            model_layers_0_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[16]
            model_layers_0_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[17]
            model_layers_0_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[18]
            model_layers_1_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[19]
            model_layers_1_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[20]
            model_layers_1_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[21]
            model_layers_1_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[22]
            model_layers_1_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[23]
            model_layers_1_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[24]
            model_layers_1_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[25]
            model_layers_1_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[26]
            model_layers_1_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[27]
            model_layers_1_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[28]
            model_layers_1_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[29]
            model_layers_1_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[30]
            model_layers_1_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[31]
            model_layers_1_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[32]
            model_layers_1_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[33]
            model_layers_1_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[34]
            model_layers_1_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[35]
            model_layers_2_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[36]
            model_layers_2_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[37]
            model_layers_2_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[38]
            model_layers_2_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[39]
            model_layers_2_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[40]
            model_layers_2_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[41]
            model_layers_2_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[42]
            model_layers_2_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[43]
            model_layers_2_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[44]
            model_layers_2_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[45]
            model_layers_2_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[46]
            model_layers_2_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[47]
            model_layers_2_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[48]
            model_layers_2_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[49]
            model_layers_2_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[50]
            model_layers_2_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[51]
            model_layers_2_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[52]
            model_layers_3_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[53]
            model_layers_3_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[54]
            model_layers_3_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[55]
            model_layers_3_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[56]
            model_layers_3_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[57]
            model_layers_3_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[58]
            model_layers_3_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[59]
            model_layers_3_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[60]
            model_layers_3_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[61]
            model_layers_3_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[62]
            model_layers_3_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[63]
            model_layers_3_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[64]
            model_layers_3_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[65]
            model_layers_3_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[66]
            model_layers_3_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[67]
            model_layers_3_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[68]
            model_layers_3_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[69]
            model_layers_4_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[70]
            model_layers_4_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[71]
            model_layers_4_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[72]
            model_layers_4_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[73]
            model_layers_4_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[74]
            model_layers_4_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[75]
            model_layers_4_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[76]
            model_layers_4_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[77]
            model_layers_4_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[78]
            model_layers_4_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[79]
            model_layers_4_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[80]
            model_layers_4_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[81]
            model_layers_4_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[82]
            model_layers_4_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[83]
            model_layers_4_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[84]
            model_layers_4_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[85]
            model_layers_4_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[86]
            model_layers_5_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[87]
            model_layers_5_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[88]
            model_layers_5_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[89]
            model_layers_5_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[90]
            model_layers_5_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[91]
            model_layers_5_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[92]
            model_layers_5_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[93]
            model_layers_5_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[94]
            model_layers_5_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[95]
            model_layers_5_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[96]
            model_layers_5_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[97]
            model_layers_5_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[98]
            model_layers_5_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[99]
            model_layers_5_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[100]
            model_layers_5_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[101]
            model_layers_5_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[102]
            model_layers_5_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[103]
            model_layers_6_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[104]
            model_layers_6_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[105]
            model_layers_6_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[106]
            model_layers_6_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[107]
            model_layers_6_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[108]
            model_layers_6_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[109]
            model_layers_6_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[110]
            model_layers_6_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[111]
            model_layers_6_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[112]
            model_layers_6_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[113]
            model_layers_6_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[114]
            model_layers_6_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[115]
            model_layers_6_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[116]
            model_layers_6_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[117]
            model_layers_6_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[118]
            model_layers_6_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[119]
            model_layers_6_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[120]
            model_layers_7_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[121]
            model_layers_7_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[122]
            model_layers_7_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[123]
            model_layers_7_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[124]
            model_layers_7_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[125]
            model_layers_7_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[126]
            model_layers_7_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[127]
            model_layers_7_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[128]
            model_layers_7_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[129]
            model_layers_7_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[130]
            model_layers_7_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[131]
            model_layers_7_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[132]
            model_layers_7_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[133]
            model_layers_7_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[134]
            model_layers_7_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[135]
            model_layers_7_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[136]
            model_layers_7_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[137]
            model_layers_8_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[138]
            model_layers_8_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[139]
            model_layers_8_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[140]
            model_layers_8_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[141]
            model_layers_8_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[142]
            model_layers_8_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[143]
            model_layers_8_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[144]
            model_layers_8_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[145]
            model_layers_8_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[146]
            model_layers_8_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[147]
            model_layers_8_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[148]
            model_layers_8_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[149]
            model_layers_8_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[150]
            model_layers_8_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[151]
            model_layers_8_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[152]
            model_layers_8_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[153]
            model_layers_8_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[154]
            model_layers_9_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[155]
            model_layers_9_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[156]
            model_layers_9_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[157]
            model_layers_9_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[158]
            model_layers_9_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[159]
            model_layers_9_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[160]
            model_layers_9_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[161]
            model_layers_9_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[162]
            model_layers_9_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[163]
            model_layers_9_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[164]
            model_layers_9_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[165]
            model_layers_9_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[166]
            model_layers_9_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[167]
            model_layers_9_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[168]
            model_layers_9_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[169]
            model_layers_9_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[170]
            model_layers_9_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[171]
            model_layers_10_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[172]
            model_layers_10_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[173]
            model_layers_10_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[174]
            model_layers_10_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[175]
            model_layers_10_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[176]
            model_layers_10_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[177]
            model_layers_10_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[178]
            model_layers_10_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[179]
            model_layers_10_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[180]
            model_layers_10_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[181]
            model_layers_10_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[182]
            model_layers_10_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[183]
            model_layers_10_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[184]
            model_layers_10_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[185]
            model_layers_10_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[186]
            model_layers_10_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[187]
            model_layers_10_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[188]
            model_layers_11_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[189]
            model_layers_11_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[190]
            model_layers_11_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[191]
            model_layers_11_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[192]
            model_layers_11_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[193]
            model_layers_11_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[194]
            model_layers_11_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[195]
            model_layers_11_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[196]
            model_layers_11_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[197]
            model_layers_11_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[198]
            model_layers_11_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[199]
            model_layers_11_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[200]
            model_layers_11_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[201]
            model_layers_11_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[202]
            model_layers_11_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[203]
            model_layers_11_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[204]
            model_layers_11_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[205]
            model_layers_12_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[206]
            model_layers_12_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[207]
            model_layers_12_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[208]
            model_layers_12_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[209]
            model_layers_12_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[210]
            model_layers_12_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[211]
            model_layers_12_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[212]
            model_layers_12_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[213]
            model_layers_12_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[214]
            model_layers_12_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[215]
            model_layers_12_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[216]
            model_layers_12_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[217]
            model_layers_12_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[218]
            model_layers_12_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[219]
            model_layers_12_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[220]
            model_layers_12_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[221]
            model_layers_12_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[222]
            model_layers_13_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[223]
            model_layers_13_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[224]
            model_layers_13_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[225]
            model_layers_13_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[226]
            model_layers_13_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[227]
            model_layers_13_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[228]
            model_layers_13_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[229]
            model_layers_13_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[230]
            model_layers_13_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[231]
            model_layers_13_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[232]
            model_layers_13_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[233]
            model_layers_13_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[234]
            model_layers_13_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[235]
            model_layers_13_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[236]
            model_layers_13_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[237]
            model_layers_13_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[238]
            model_layers_13_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[239]
            model_layers_14_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[240]
            model_layers_14_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[241]
            model_layers_14_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[242]
            model_layers_14_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[243]
            model_layers_14_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[244]
            model_layers_14_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[245]
            model_layers_14_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[246]
            model_layers_14_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[247]
            model_layers_14_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[248]
            model_layers_14_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[249]
            model_layers_14_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[250]
            model_layers_14_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[251]
            model_layers_14_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[252]
            model_layers_14_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[253]
            model_layers_14_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[254]
            model_layers_14_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[255]
            model_layers_14_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[256]
            model_layers_15_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[257]
            model_layers_15_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[258]
            model_layers_15_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[259]
            model_layers_15_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[260]
            model_layers_15_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[261]
            model_layers_15_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[262]
            model_layers_15_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[263]
            model_layers_15_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[264]
            model_layers_15_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[265]
            model_layers_15_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[266]
            model_layers_15_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[267]
            model_layers_15_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[268]
            model_layers_15_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[269]
            model_layers_15_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[270]
            model_layers_15_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[271]
            model_layers_15_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[272]
            model_layers_15_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[273]
            model_layers_16_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[274]
            model_layers_16_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[275]
            model_layers_16_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[276]
            model_layers_16_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[277]
            model_layers_16_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[278]
            model_layers_16_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[279]
            model_layers_16_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[280]
            model_layers_16_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[281]
            model_layers_16_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[282]
            model_layers_16_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[283]
            model_layers_16_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[284]
            model_layers_16_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[285]
            model_layers_16_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[286]
            model_layers_16_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[287]
            model_layers_16_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[288]
            model_layers_16_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[289]
            model_layers_16_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[290]
            model_layers_17_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[291]
            model_layers_17_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[292]
            model_layers_17_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[293]
            model_layers_17_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[294]
            model_layers_17_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[295]
            model_layers_17_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[296]
            model_layers_17_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[297]
            model_layers_17_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[298]
            model_layers_17_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[299]
            model_layers_17_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[300]
            model_layers_17_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[301]
            model_layers_17_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[302]
            model_layers_17_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[303]
            model_layers_17_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[304]
            model_layers_17_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[305]
            model_layers_17_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[306]
            model_layers_17_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[307]
            model_layers_18_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[308]
            model_layers_18_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[309]
            model_layers_18_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[310]
            model_layers_18_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[311]
            model_layers_18_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[312]
            model_layers_18_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[313]
            model_layers_18_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[314]
            model_layers_18_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[315]
            model_layers_18_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[316]
            model_layers_18_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[317]
            model_layers_18_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[318]
            model_layers_18_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[319]
            model_layers_18_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[320]
            model_layers_18_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[321]
            model_layers_18_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[322]
            model_layers_18_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[323]
            model_layers_18_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[324]
            model_layers_19_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[325]
            model_layers_19_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[326]
            model_layers_19_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[327]
            model_layers_19_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[328]
            model_layers_19_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[329]
            model_layers_19_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[330]
            model_layers_19_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[331]
            model_layers_19_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[332]
            model_layers_19_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[333]
            model_layers_19_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[334]
            model_layers_19_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[335]
            model_layers_19_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[336]
            model_layers_19_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[337]
            model_layers_19_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[338]
            model_layers_19_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[339]
            model_layers_19_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[340]
            model_layers_19_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[341]
            model_layers_20_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[342]
            model_layers_20_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[343]
            model_layers_20_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[344]
            model_layers_20_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[345]
            model_layers_20_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[346]
            model_layers_20_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[347]
            model_layers_20_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[348]
            model_layers_20_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[349]
            model_layers_20_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[350]
            model_layers_20_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[351]
            model_layers_20_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[352]
            model_layers_20_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[353]
            model_layers_20_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[354]
            model_layers_20_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[355]
            model_layers_20_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[356]
            model_layers_20_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[357]
            model_layers_20_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[358]
            model_layers_21_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[359]
            model_layers_21_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[360]
            model_layers_21_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[361]
            model_layers_21_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[362]
            model_layers_21_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[363]
            model_layers_21_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[364]
            model_layers_21_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[365]
            model_layers_21_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[366]
            model_layers_21_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[367]
            model_layers_21_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[368]
            model_layers_21_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[369]
            model_layers_21_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[370]
            model_layers_21_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[371]
            model_layers_21_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[372]
            model_layers_21_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[373]
            model_layers_21_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[374]
            model_layers_21_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[375]
            model_layers_22_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[376]
            model_layers_22_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[377]
            model_layers_22_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[378]
            model_layers_22_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[379]
            model_layers_22_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[380]
            model_layers_22_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[381]
            model_layers_22_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[382]
            model_layers_22_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[383]
            model_layers_22_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[384]
            model_layers_22_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[385]
            model_layers_22_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[386]
            model_layers_22_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[387]
            model_layers_22_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[388]
            model_layers_22_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[389]
            model_layers_22_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[390]
            model_layers_22_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[391]
            model_layers_22_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[392]
            model_layers_23_self_attn_c_attn_q_weight5: R.Tensor((6144, 256), dtype="uint32") = packed_params[393]
            model_layers_23_self_attn_c_attn_q_scale5: R.Tensor((6144, 64), dtype="float16") = packed_params[394]
            model_layers_23_self_attn_c_attn_bias5: R.Tensor((6144,), dtype="float16") = packed_params[395]
            model_layers_23_self_attn_o_proj_q_weight5: R.Tensor((2048, 256), dtype="uint32") = packed_params[396]
            model_layers_23_self_attn_o_proj_q_scale5: R.Tensor((2048, 64), dtype="float16") = packed_params[397]
            model_layers_23_mlp_shared_expert_gate_up_proj_q_weight5: R.Tensor((11264, 256), dtype="uint32") = packed_params[398]
            model_layers_23_mlp_shared_expert_gate_up_proj_q_scale5: R.Tensor((11264, 64), dtype="float16") = packed_params[399]
            model_layers_23_mlp_shared_expert_down_proj_q_weight5: R.Tensor((2048, 704), dtype="uint32") = packed_params[400]
            model_layers_23_mlp_shared_expert_down_proj_q_scale5: R.Tensor((2048, 176), dtype="float16") = packed_params[401]
            model_layers_23_mlp_shared_expert_gate_weight5: R.Tensor((1, 2048), dtype="float16") = packed_params[402]
            model_layers_23_mlp_gate_weight5: R.Tensor((60, 2048), dtype="float16") = packed_params[403]
            model_layers_23_mlp_moe_gate_up_proj_q_weight5: R.Tensor((60, 2816, 256), dtype="uint32") = packed_params[404]
            model_layers_23_mlp_moe_gate_up_proj_q_scale5: R.Tensor((60, 2816, 64), dtype="float16") = packed_params[405]
            model_layers_23_mlp_moe_down_proj_q_weight5: R.Tensor((60, 2048, 176), dtype="uint32") = packed_params[406]
            model_layers_23_mlp_moe_down_proj_q_scale5: R.Tensor((60, 2048, 44), dtype="float16") = packed_params[407]
            model_layers_23_input_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[408]
            model_layers_23_post_attention_layernorm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[409]
            model_norm_weight5: R.Tensor((2048,), dtype="float16") = packed_params[410]
            lm_head_q_weight5: R.Tensor((151936, 256), dtype="uint32") = packed_params[411]
            lm_head_q_scale5: R.Tensor((151936, 64), dtype="float16") = packed_params[412]
            rms_norm196 = R.call_tir(cls.rms_norm1, (input_embeds, model_layers_0_input_layernorm_weight5), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv48 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_0_self_attn_c_attn_q_weight5, model_layers_0_self_attn_c_attn_q_scale5, rms_norm196, model_layers_0_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape840 = R.call_tir(cls.reshape9, (lv48,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape841 = R.call_tir(cls.reshape10, (reshape840,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv1063 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(0), R.prim_value(T.float32(1.0)), reshape841), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape842 = R.call_tir(cls.reshape11, (lv1063,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape843 = R.call_tir(cls.reshape12, (reshape842,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv98 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_0_self_attn_o_proj_q_weight5, model_layers_0_self_attn_o_proj_q_scale5, reshape843), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv192 = R.call_tir(cls.fuse_add_norm_prefill, (lv98, input_embeds, model_layers_0_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv193: R.Tensor((1, seq_len, 2048), dtype="float16") = lv192[1]
            rms_norm197: R.Tensor((1, seq_len, 2048), dtype="float16") = lv192[0]
            reshape844 = R.call_tir(cls.reshape13, (rms_norm197,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv435 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape844, model_layers_0_mlp_gate_weight5), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv436 = R.call_tir(cls.fused_softmax_cast1, (lv435,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv1065 = R.call_tir(cls.top4_softmax, (lv436,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_096: R.Tensor((seq_len, 4), dtype="float16") = lv1065[0]
            top4_softmax_196: R.Tensor((seq_len, 4), dtype="int32") = lv1065[1]
            lv437 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_196,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape845 = R.call_tir(cls.reshape5, (lv437,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv96: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape845, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv97 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv96,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum72: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv97, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv1067 = R.call_tir(cls.get_indices, (cumsum72, top4_softmax_196), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_072: R.Tensor((seq_len * 4,), dtype="int32") = lv1067[0]
            get_indices_172: R.Tensor((seq_len * 4,), dtype="int32") = lv1067[1]
            lv1068 = R.call_tir(cls.get_expert_instance_indptr, (cumsum72,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take74 = R.call_tir(cls.take, (reshape844, get_indices_172), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv1069 = R.call_tir(cls.dequantize_group_gemm, (take74, model_layers_0_mlp_moe_gate_up_proj_q_weight5, model_layers_0_mlp_moe_gate_up_proj_q_scale5, lv1068), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv438 = R.call_tir(cls.fused_split_silu_multiply, (lv1069,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv1070 = R.call_tir(cls.dequantize_group_gemm1, (lv438, model_layers_0_mlp_moe_down_proj_q_weight5, model_layers_0_mlp_moe_down_proj_q_scale5, lv1068), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv1071 = R.call_tir(cls.scatter_output, (lv1070, get_indices_072), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape846 = R.call_tir(cls.reshape6, (top4_softmax_096,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape847 = R.call_tir(cls.reshape7, (lv1071,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv439 = R.call_tir(cls.fused_multiply1_sum, (reshape847, reshape846), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv99 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_0_mlp_shared_expert_gate_up_proj_q_weight5, model_layers_0_mlp_shared_expert_gate_up_proj_q_scale5, reshape844), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv440 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv99,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv441 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape844, model_layers_0_mlp_shared_expert_gate_weight5), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv48_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_0_mlp_shared_expert_down_proj_q_weight5, model_layers_0_mlp_shared_expert_down_proj_q_scale5, lv440, lv441, lv439), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape848 = R.call_tir(cls.reshape14, (lv48_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv194 = R.call_tir(cls.fuse_add_norm_prefill, (reshape848, lv193, model_layers_1_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv195: R.Tensor((1, seq_len, 2048), dtype="float16") = lv194[1]
            rms_norm198: R.Tensor((1, seq_len, 2048), dtype="float16") = lv194[0]
            lv49 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_1_self_attn_c_attn_q_weight5, model_layers_1_self_attn_c_attn_q_scale5, rms_norm198, model_layers_1_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape849 = R.call_tir(cls.reshape9, (lv49,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape850 = R.call_tir(cls.reshape10, (reshape849,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv1075 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(1), R.prim_value(T.float32(1.0)), reshape850), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape851 = R.call_tir(cls.reshape11, (lv1075,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape852 = R.call_tir(cls.reshape12, (reshape851,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv100 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_1_self_attn_o_proj_q_weight5, model_layers_1_self_attn_o_proj_q_scale5, reshape852), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv196 = R.call_tir(cls.fuse_add_norm_prefill, (lv100, lv195, model_layers_1_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv197: R.Tensor((1, seq_len, 2048), dtype="float16") = lv196[1]
            rms_norm199: R.Tensor((1, seq_len, 2048), dtype="float16") = lv196[0]
            reshape853 = R.call_tir(cls.reshape13, (rms_norm199,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv444 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape853, model_layers_1_mlp_gate_weight5), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv445 = R.call_tir(cls.fused_softmax_cast1, (lv444,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv1077 = R.call_tir(cls.top4_softmax, (lv445,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_097: R.Tensor((seq_len, 4), dtype="float16") = lv1077[0]
            top4_softmax_197: R.Tensor((seq_len, 4), dtype="int32") = lv1077[1]
            lv446 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_197,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape854 = R.call_tir(cls.reshape5, (lv446,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv98_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape854, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv99_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv98_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum73: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv99_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv1079 = R.call_tir(cls.get_indices, (cumsum73, top4_softmax_197), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_073: R.Tensor((seq_len * 4,), dtype="int32") = lv1079[0]
            get_indices_173: R.Tensor((seq_len * 4,), dtype="int32") = lv1079[1]
            lv1080 = R.call_tir(cls.get_expert_instance_indptr, (cumsum73,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take75 = R.call_tir(cls.take, (reshape853, get_indices_173), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv1081 = R.call_tir(cls.dequantize_group_gemm, (take75, model_layers_1_mlp_moe_gate_up_proj_q_weight5, model_layers_1_mlp_moe_gate_up_proj_q_scale5, lv1080), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv447 = R.call_tir(cls.fused_split_silu_multiply, (lv1081,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv1082 = R.call_tir(cls.dequantize_group_gemm1, (lv447, model_layers_1_mlp_moe_down_proj_q_weight5, model_layers_1_mlp_moe_down_proj_q_scale5, lv1080), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv1083 = R.call_tir(cls.scatter_output, (lv1082, get_indices_073), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape855 = R.call_tir(cls.reshape6, (top4_softmax_097,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape856 = R.call_tir(cls.reshape7, (lv1083,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv448 = R.call_tir(cls.fused_multiply1_sum, (reshape856, reshape855), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv101 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_1_mlp_shared_expert_gate_up_proj_q_weight5, model_layers_1_mlp_shared_expert_gate_up_proj_q_scale5, reshape853), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv449 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv101,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv450 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape853, model_layers_1_mlp_shared_expert_gate_weight5), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv49_1 = R.call_tir(cls.fused_dequantize4_fused_NT_matmul4_multiply3_add1, (model_layers_1_mlp_shared_expert_down_proj_q_weight5, model_layers_1_mlp_shared_expert_down_proj_q_scale5, lv449, lv450, lv448), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            reshape857 = R.call_tir(cls.reshape14, (lv49_1,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv198 = R.call_tir(cls.fuse_add_norm_prefill, (reshape857, lv197, model_layers_2_input_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv199: R.Tensor((1, seq_len, 2048), dtype="float16") = lv198[1]
            rms_norm200: R.Tensor((1, seq_len, 2048), dtype="float16") = lv198[0]
            lv50 = R.call_tir(cls.fused_dequantize1_fused_NT_matmul7_add2, (model_layers_2_self_attn_c_attn_q_weight5, model_layers_2_self_attn_c_attn_q_scale5, rms_norm200, model_layers_2_self_attn_c_attn_bias5), out_sinfo=R.Tensor((1, seq_len, 6144), dtype="float16"))
            reshape858 = R.call_tir(cls.reshape9, (lv50,), out_sinfo=R.Tensor((1, seq_len, 48, 128), dtype="float16"))
            reshape859 = R.call_tir(cls.reshape10, (reshape858,), out_sinfo=R.Tensor((seq_len, 48, 128), dtype="float16"))
            lv1087 = R.call_dps_packed("vm.builtin.attention_kv_cache_attention_with_fused_qkv", (paged_kv_cache, R.prim_value(2), R.prim_value(T.float32(1.0)), reshape859), out_sinfo=R.Tensor((seq_len, 16, 128), dtype="float16"))
            reshape860 = R.call_tir(cls.reshape11, (lv1087,), out_sinfo=R.Tensor((1, seq_len, 16, 128), dtype="float16"))
            reshape861 = R.call_tir(cls.reshape12, (reshape860,), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv102 = R.call_tir(cls.fused_dequantize2_NT_matmul8, (model_layers_2_self_attn_o_proj_q_weight5, model_layers_2_self_attn_o_proj_q_scale5, reshape861), out_sinfo=R.Tensor((1, seq_len, 2048), dtype="float16"))
            lv200 = R.call_tir(cls.fuse_add_norm_prefill, (lv102, lv199, model_layers_2_post_attention_layernorm_weight5), out_sinfo=[R.Tensor((1, seq_len, 2048), dtype="float16"), R.Tensor((1, seq_len, 2048), dtype="float16")])
            lv201: R.Tensor((1, seq_len, 2048), dtype="float16") = lv200[1]
            rms_norm201: R.Tensor((1, seq_len, 2048), dtype="float16") = lv200[0]
            reshape862 = R.call_tir(cls.reshape13, (rms_norm201,), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv453 = R.call_tir(cls.fused_NT_matmul2_cast, (reshape862, model_layers_2_mlp_gate_weight5), out_sinfo=R.Tensor((seq_len, 60), dtype="float32"))
            lv454 = R.call_tir(cls.fused_softmax_cast1, (lv453,), out_sinfo=R.Tensor((seq_len, 60), dtype="float16"))
            lv1089 = R.call_tir(cls.top4_softmax, (lv454,), out_sinfo=[R.Tensor((seq_len, 4), dtype="float16"), R.Tensor((seq_len, 4), dtype="int32")])
            top4_softmax_098: R.Tensor((seq_len, 4), dtype="float16") = lv1089[0]
            top4_softmax_198: R.Tensor((seq_len, 4), dtype="int32") = lv1089[1]
            lv455 = R.call_tir(cls.fused_expert_mask_transpose, (top4_softmax_198,), out_sinfo=R.Tensor((60, seq_len), dtype="int32"))
            reshape863 = R.call_tir(cls.reshape5, (lv455,), out_sinfo=R.Tensor((seq_len * 60,), dtype="int32"))
            lv100_1: R.Tensor((1, seq_len * 60), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", reshape863, R.shape([1, seq_len * 60]), sinfo_args=(R.Tensor((1, seq_len * 60), dtype="int32"),))
            lv101_1 = R.call_tir(cls.gpu_2d_continuous_cumsum, (lv100_1,), out_sinfo=R.Tensor((1, seq_len * 60), dtype="int32"))
            cumsum74: R.Tensor((seq_len * 60,), dtype="int32") = R.call_pure_packed("vm.builtin.reshape", lv101_1, R.shape([seq_len * 60]), sinfo_args=(R.Tensor((seq_len * 60,), dtype="int32"),))
            lv1091 = R.call_tir(cls.get_indices, (cumsum74, top4_softmax_198), out_sinfo=[R.Tensor((seq_len * 4,), dtype="int32"), R.Tensor((seq_len * 4,), dtype="int32")])
            get_indices_074: R.Tensor((seq_len * 4,), dtype="int32") = lv1091[0]
            get_indices_174: R.Tensor((seq_len * 4,), dtype="int32") = lv1091[1]
            lv1092 = R.call_tir(cls.get_expert_instance_indptr, (cumsum74,), out_sinfo=R.Tensor((61,), dtype="int32"), tir_vars=R.shape([seq_len]))
            take76 = R.call_tir(cls.take, (reshape862, get_indices_174), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv1093 = R.call_tir(cls.dequantize_group_gemm, (take76, model_layers_2_mlp_moe_gate_up_proj_q_weight5, model_layers_2_mlp_moe_gate_up_proj_q_scale5, lv1092), out_sinfo=R.Tensor((seq_len * 4, 2816), dtype="float16"))
            lv456 = R.call_tir(cls.fused_split_silu_multiply, (lv1093,), out_sinfo=R.Tensor((seq_len * 4, 1408), dtype="float16"), tir_vars=R.shape([seq_len]))
            lv1094 = R.call_tir(cls.dequantize_group_gemm1, (lv456, model_layers_2_mlp_moe_down_proj_q_weight5, model_layers_2_mlp_moe_down_proj_q_scale5, lv1092), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            lv1095 = R.call_tir(cls.scatter_output, (lv1094, get_indices_074), out_sinfo=R.Tensor((seq_len * 4, 2048), dtype="float16"))
            reshape864 = R.call_tir(cls.reshape6, (top4_softmax_098,), out_sinfo=R.Tensor((seq_len, 4, 1), dtype="float16"))
            reshape865 = R.call_tir(cls.reshape7, (lv1095,), out_sinfo=R.Tensor((seq_len, 4, 2048), dtype="float16"))
            lv457 = R.call_tir(cls.fused_multiply1_sum, (reshape865, reshape864), out_sinfo=R.Tensor((seq_len, 2048), dtype="float16"))
            lv103 = R.call_tir(cls.fused_dequantize3_NT_matmul3, (model_layers_2_mlp_shared_expert_gate_up_proj_q_weight5, model_layers_2_mlp_shared_expert_gate_up_proj_q_scale5, reshape862), out_sinfo=R.Tensor((seq_len, 11264), dtype="float16"))
            lv458 = R.call_tir(cls.fused_split1_silu1_multiply2, (lv103,), out_sinfo=R.Tensor((seq_len, 5632), dtype="float16"))
            lv459 = R.call_tir(cls.fused_NT_matmul5_tir_sigmoid, (reshape862, model_layers_2_mlp_shared_expert_gate_weight5), out_sinfo=R.Tensor((seq_len, 1), dtype="float16"))
            lv50_1 = R.call_tir(cls.fused_dequan